diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py b/python/tvm/relax/backend/dispatch_sort_scan.py index bb3f57ce96ae..a223b64ad026 100644 --- a/python/tvm/relax/backend/dispatch_sort_scan.py +++ b/python/tvm/relax/backend/dispatch_sort_scan.py @@ -116,6 +116,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: tir_call = self.builder_.call_te( te_func, call.args[0], + k=call.attrs.k, axis=call.attrs.axis, ret_type=call.attrs.ret_type, is_ascend=not call.attrs.largest, diff --git a/tests/python/relax/test_backend_dispatch_sort_scan.py b/tests/python/relax/test_backend_dispatch_sort_scan.py index 8921372f2f6c..4d08189ac86f 100644 --- a/tests/python/relax/test_backend_dispatch_sort_scan.py +++ b/tests/python/relax/test_backend_dispatch_sort_scan.py @@ -38,84 +38,27 @@ class Before: def foo(x: R.Tensor((2, 3), "float32", "llvm")): with R.dataflow(): lv0 = R.cumsum(x, axis=1, dtype="float64", exclusive=False) - gv = R.cumprod(lv0, axis=1, dtype="float64", exclusive=False) + lv1 = R.cumprod(lv0, axis=1, dtype="float64", exclusive=False) + gv = lv1 R.output(gv) return gv - @I.ir_module - class Expected: - I.module_global_infos({"vdevice": [I.vdevice("llvm", 0)]}) + mod = DispatchSortScan()(Before) - @T.prim_func(private=True) - def cumsum(var_A: T.handle, out_buf: T.Buffer((T.int64(2), T.int64(3)), "float64")): - T.func_attr({"tir.noalias": T.bool(True)}) - A = T.match_buffer(var_A, (T.int64(2), T.int64(3)), offset_factor=1) - with T.block("cumsum_generic"): - for fused in T.parallel(T.int64(2)): - out_buf[ - fused * T.int64(3) // T.int64(3), fused * T.int64(3) % T.int64(3) - ] = T.Cast( - "float64", - A[fused * T.int64(3) // T.int64(3), fused * T.int64(3) % T.int64(3)], - ) - for _k in range(T.int64(2)): - out_buf[ - (fused * T.int64(3) + (_k + T.int64(1))) // T.int64(3), - (fused * T.int64(3) + (_k + T.int64(1))) % T.int64(3), - ] = out_buf[ - (fused * T.int64(3) + (_k + T.int64(1) - T.int64(1))) // T.int64(3), - (fused * T.int64(3) + (_k + T.int64(1) - T.int64(1))) % T.int64(3), - ] + T.Cast( - "float64", - A[ - (fused * T.int64(3) + (_k + T.int64(1))) // T.int64(3), - (fused * T.int64(3) + (_k + T.int64(1))) % T.int64(3), - ], - ) - - @T.prim_func(private=True) - def cumprod(var_A: T.handle, out_buf: T.Buffer((T.int64(2), T.int64(3)), "float64")): - T.func_attr({"tir.noalias": T.bool(True)}) - A = T.match_buffer(var_A, (T.int64(2), T.int64(3)), "float64", offset_factor=1) - with T.block("cumprod_generic"): - T.reads(A[T.int64(0) : T.int64(2), T.int64(0) : T.int64(3)]) - T.writes(out_buf[T.int64(0) : T.int64(2), T.int64(0) : T.int64(3)]) - for fused in T.parallel(T.int64(2)): - out_buf[fused * T.int64(3) // T.int64(3), fused * T.int64(3) % T.int64(3)] = A[ - fused * T.int64(3) // T.int64(3), fused * T.int64(3) % T.int64(3) - ] - for _k in range(T.int64(2)): - out_buf[ - (fused * T.int64(3) + (_k + T.int64(1))) // T.int64(3), - (fused * T.int64(3) + (_k + T.int64(1))) % T.int64(3), - ] = ( - out_buf[ - (fused * T.int64(3) + (_k + T.int64(1) - T.int64(1))) // T.int64(3), - (fused * T.int64(3) + (_k + T.int64(1) - T.int64(1))) % T.int64(3), - ] - * A[ - (fused * T.int64(3) + (_k + T.int64(1))) // T.int64(3), - (fused * T.int64(3) + (_k + T.int64(1))) % T.int64(3), - ] - ) + vdevices = [I.vdevice("llvm", 0)] + x = relax.Var("x", R.Tensor((2, 3), "float32", vdevices[0])) + bb = relax.BlockBuilder() - @R.function - def foo( - x: R.Tensor((2, 3), dtype="float32", vdevice="llvm") - ) -> R.Tensor((2, 3), dtype="float64", vdevice="llvm"): - cls = Expected - with R.dataflow(): - lv0 = R.call_tir(cls.cumsum, (x,), out_sinfo=R.Tensor((2, 3), "float64", "llvm")) - gv = R.call_tir( - cls.cumprod, - (lv0,), - out_sinfo=R.Tensor((2, 3), dtype="float64", vdevice="llvm"), - ) - R.output(gv) - return gv + with bb.function("foo", (x,), {"global_symbol": "foo"}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.cumsum, x, axis=1, dtype="float64", exclusive=False) + out = bb.emit_te(topi.cumprod, lv0, axis=1, dtype="float64", exclusive=False) + out = bb.emit_output(out) + bb.emit_func_output(out) + expected_mod = bb.finalize() + expected_mod.update_global_info("vdevice", vdevices) - mod = DispatchSortScan()(Before) - assert_structural_equal(mod, Expected) + assert_structural_equal(mod, expected_mod) def test_dispatch_scanop_cuda(): @@ -172,60 +115,26 @@ class Before: def foo(x: R.Tensor(("m", 3), "float32", "llvm")): m = T.int64() with R.dataflow(): - gv = R.sort(x, axis=1, descending=False) + lv = R.sort(x, axis=1, descending=False) + gv = lv R.output(gv) return gv - @I.ir_module - class Expected: - I.module_global_infos({"vdevice": [I.vdevice("llvm", 0)]}) - - @T.prim_func(private=True) - def sort(var_A: T.handle, var_sort_cpu: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - m = T.int64() - data_buf = T.match_buffer(var_A, (m, T.int64(3)), align=8) - out_buf = T.match_buffer(var_sort_cpu, (m, T.int64(3)), align=8) - with T.block("sort_cpu"): - T.reads() - T.writes() - T.call_packed( - "tvm.contrib.sort.sort", - T.tvm_stack_make_array( - data_buf.data, - T.tvm_stack_make_shape(m, T.int64(3)), - 0, - 2, - T.float32(0), - T.int64(0), - ), - T.tvm_stack_make_array( - out_buf.data, - T.tvm_stack_make_shape(m, T.int64(3)), - 0, - 2, - T.float32(0), - T.int64(0), - ), - 1, - T.bool(True), - ) + vdevices = [I.vdevice("llvm", 0)] + m = tir.Var("m", "int64") + x = relax.Var("x", R.Tensor((m, 3), "float32", vdevices[0])) + bb = relax.BlockBuilder() - @R.function - def foo( - x: R.Tensor(("m", 3), dtype="float32", vdevice="llvm") - ) -> R.Tensor(("m", 3), dtype="float32", vdevice="llvm"): - m = T.int64() - cls = Expected - with R.dataflow(): - gv = R.call_tir( - cls.sort, (x,), out_sinfo=R.Tensor((m, 3), dtype="float32", vdevice="llvm") - ) - R.output(gv) - return gv + with bb.function("foo", (x,), {"global_symbol": "foo"}): + with bb.dataflow(): + out = bb.emit_te(topi.sort, x, axis=1, is_ascend=True) + out = bb.emit_output(out) + bb.emit_func_output(out) + expected_mod = bb.finalize() + expected_mod.update_global_info("vdevice", vdevices) mod = DispatchSortScan()(Before) - assert_structural_equal(mod, Expected) + assert_structural_equal(mod, expected_mod) def test_dispatch_sort_cuda(): @@ -295,55 +204,26 @@ class Before: def foo(x: R.Tensor(("m", 3), "float32", "llvm")): m = T.int64() with R.dataflow(): - gv = R.argsort(x, axis=1, descending=False) + lv = R.argsort(x, axis=1, descending=False, dtype="int32") + gv = lv R.output(gv) return gv - @I.ir_module - class Expected: - I.module_global_infos({"vdevice": [I.vdevice("llvm", 0)]}) - - @T.prim_func(private=True) - def argsort(var_A: T.handle, var_argsort_cpu: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - m = T.int64() - data_buf = T.match_buffer(var_A, (m, T.int64(3)), align=8) - out_buf = T.match_buffer(var_argsort_cpu, (m, T.int64(3)), "int32", align=8) - with T.block("argsort_cpu"): - T.reads(data_buf[T.int64(0) : m, T.int64(0) : T.int64(3)]) - T.writes(out_buf[T.int64(0) : m, T.int64(0) : T.int64(3)]) - T.call_packed( - "tvm.contrib.sort.argsort", - T.tvm_stack_make_array( - data_buf.data, - T.tvm_stack_make_shape(m, T.int64(3)), - 0, - 2, - T.float32(0), - T.int64(0), - ), - T.tvm_stack_make_array( - out_buf.data, T.tvm_stack_make_shape(m, T.int64(3)), 0, 2, 0, T.int64(0) - ), - 1, - T.bool(True), - ) + vdevices = [I.vdevice("llvm", 0)] + m = tir.Var("m", "int64") + x = relax.Var("x", R.Tensor((m, 3), "float32", vdevices[0])) + bb = relax.BlockBuilder() - @R.function - def foo( - x: R.Tensor(("m", 3), dtype="float32", vdevice="llvm") - ) -> R.Tensor(("m", 3), dtype="int32", vdevice="llvm"): - m = T.int64() - cls = Expected - with R.dataflow(): - gv = R.call_tir( - cls.argsort, (x,), out_sinfo=R.Tensor((m, 3), dtype="int32", vdevice="llvm") - ) - R.output(gv) - return gv + with bb.function("foo", (x,), {"global_symbol": "foo"}): + with bb.dataflow(): + out = bb.emit_te(topi.argsort, x, axis=1, is_ascend=True, dtype="int32") + out = bb.emit_output(out) + bb.emit_func_output(out) + expected_mod = bb.finalize() + expected_mod.update_global_info("vdevice", vdevices) mod = DispatchSortScan()(Before) - assert_structural_equal(mod, Expected) + assert_structural_equal(mod, expected_mod) def test_dispatch_argsort_cuda(): @@ -410,71 +290,26 @@ class Before: def foo(x: R.Tensor(("m", 3), "float32", "llvm")): m = T.int64() with R.dataflow(): - gv = R.topk(x, k=2, axis=1, largest=True) + lv = R.topk(x, k=2, axis=1, largest=True) + gv = lv R.output(gv) return gv - @I.ir_module - class Expected: - I.module_global_infos({"vdevice": [I.vdevice("llvm", 0)]}) - - @T.prim_func(private=True) - def topk(var_A: T.handle, var_topk_cpu_v0: T.handle, var_topk_cpu_v1: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - m = T.int64() - data_buf = T.match_buffer(var_A, (m, T.int64(3)), align=8) - value_buf = T.match_buffer(var_topk_cpu_v0, (m, T.int64(1)), align=8) - indices_buf = T.match_buffer(var_topk_cpu_v1, (m, T.int64(1)), "int32", align=8) - with T.block("topk_cpu"): - T.reads(data_buf[T.int64(0) : m, T.int64(0) : T.int64(3)]) - T.writes( - value_buf[T.int64(0) : m, T.int64(0)], indices_buf[T.int64(0) : m, T.int64(0)] - ) - T.call_packed( - "tvm.contrib.sort.topk", - T.tvm_stack_make_array( - data_buf.data, - T.tvm_stack_make_shape(m, T.int64(3)), - 0, - 2, - T.float32(0), - T.int64(0), - ), - T.tvm_stack_make_array( - value_buf.data, T.tvm_stack_make_shape(m, 1), 0, 2, T.float32(0), T.int64(0) - ), - T.tvm_stack_make_array( - indices_buf.data, T.tvm_stack_make_shape(m, 1), 0, 2, 0, T.int64(0) - ), - 1, - 1, - "both", - T.bool(False), - ) + vdevices = [I.vdevice("llvm", 0)] + m = tir.Var("m", "int64") + x = relax.Var("x", R.Tensor((m, 3), "float32", vdevices[0])) + bb = relax.BlockBuilder() - @R.function - def foo( - x: R.Tensor(("m", 3), dtype="float32", vdevice="llvm") - ) -> R.Tuple( - R.Tensor(("m", 1), dtype="float32", vdevice="llvm"), - R.Tensor(("m", 1), dtype="int32", vdevice="llvm"), - ): - m = T.int64() - cls = Expected - with R.dataflow(): - gv = R.call_tir( - cls.topk, - (x,), - out_sinfo=[ - R.Tensor((m, 1), dtype="float32", vdevice="llvm"), - R.Tensor((m, 1), dtype="int32", vdevice="llvm"), - ], - ) - R.output(gv) - return gv + with bb.function("foo", (x,), {"global_symbol": "foo"}): + with bb.dataflow(): + out = bb.emit_te(topi.topk, x, k=2, axis=1, is_ascend=False, dtype="int32") + out = bb.emit_output(out) + bb.emit_func_output(out) + expected_mod = bb.finalize() + expected_mod.update_global_info("vdevice", vdevices) mod = DispatchSortScan()(Before) - assert_structural_equal(mod, Expected) + assert_structural_equal(mod, expected_mod) def test_dispatch_topk_cuda(): @@ -494,12 +329,11 @@ def foo(x: R.Tensor((2, 3), "float32", "cuda")): vdevices = [I.vdevice("cuda", 0)] x = relax.Var("x", R.Tensor((2, 3), "float32", vdevices[0])) - y = relax.Var("y", R.Tensor((2, 3), "float32")) bb = relax.BlockBuilder() with target: with bb.function("foo", (x,), {"global_symbol": "foo"}): with bb.dataflow(): - out = bb.emit_te(topi.cuda.topk, x, axis=1, is_ascend=False, dtype="int32") + out = bb.emit_te(topi.cuda.topk, x, k=2, axis=1, is_ascend=False, dtype="int32") out = bb.emit_output(out) bb.emit_func_output(out) expected_mod = bb.finalize()