Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/relax/backend/dispatch_sort_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
tir_call = self.builder_.call_te(
te_func,
call.args[0],
k=call.attrs.k,
axis=call.attrs.axis,
ret_type=call.attrs.ret_type,
is_ascend=not call.attrs.largest,
Expand Down
282 changes: 58 additions & 224 deletions tests/python/relax/test_backend_dispatch_sort_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,84 +38,27 @@ class Before:
def foo(x: R.Tensor((2, 3), "float32", "llvm")):
with R.dataflow():
lv0 = R.cumsum(x, axis=1, dtype="float64", exclusive=False)
gv = R.cumprod(lv0, axis=1, dtype="float64", exclusive=False)
lv1 = R.cumprod(lv0, axis=1, dtype="float64", exclusive=False)
gv = lv1
R.output(gv)
return gv

@I.ir_module
class Expected:
I.module_global_infos({"vdevice": [I.vdevice("llvm", 0)]})
mod = DispatchSortScan()(Before)

@T.prim_func(private=True)
def cumsum(var_A: T.handle, out_buf: T.Buffer((T.int64(2), T.int64(3)), "float64")):
T.func_attr({"tir.noalias": T.bool(True)})
A = T.match_buffer(var_A, (T.int64(2), T.int64(3)), offset_factor=1)
with T.block("cumsum_generic"):
for fused in T.parallel(T.int64(2)):
out_buf[
fused * T.int64(3) // T.int64(3), fused * T.int64(3) % T.int64(3)
] = T.Cast(
"float64",
A[fused * T.int64(3) // T.int64(3), fused * T.int64(3) % T.int64(3)],
)
for _k in range(T.int64(2)):
out_buf[
(fused * T.int64(3) + (_k + T.int64(1))) // T.int64(3),
(fused * T.int64(3) + (_k + T.int64(1))) % T.int64(3),
] = out_buf[
(fused * T.int64(3) + (_k + T.int64(1) - T.int64(1))) // T.int64(3),
(fused * T.int64(3) + (_k + T.int64(1) - T.int64(1))) % T.int64(3),
] + T.Cast(
"float64",
A[
(fused * T.int64(3) + (_k + T.int64(1))) // T.int64(3),
(fused * T.int64(3) + (_k + T.int64(1))) % T.int64(3),
],
)

@T.prim_func(private=True)
def cumprod(var_A: T.handle, out_buf: T.Buffer((T.int64(2), T.int64(3)), "float64")):
T.func_attr({"tir.noalias": T.bool(True)})
A = T.match_buffer(var_A, (T.int64(2), T.int64(3)), "float64", offset_factor=1)
with T.block("cumprod_generic"):
T.reads(A[T.int64(0) : T.int64(2), T.int64(0) : T.int64(3)])
T.writes(out_buf[T.int64(0) : T.int64(2), T.int64(0) : T.int64(3)])
for fused in T.parallel(T.int64(2)):
out_buf[fused * T.int64(3) // T.int64(3), fused * T.int64(3) % T.int64(3)] = A[
fused * T.int64(3) // T.int64(3), fused * T.int64(3) % T.int64(3)
]
for _k in range(T.int64(2)):
out_buf[
(fused * T.int64(3) + (_k + T.int64(1))) // T.int64(3),
(fused * T.int64(3) + (_k + T.int64(1))) % T.int64(3),
] = (
out_buf[
(fused * T.int64(3) + (_k + T.int64(1) - T.int64(1))) // T.int64(3),
(fused * T.int64(3) + (_k + T.int64(1) - T.int64(1))) % T.int64(3),
]
* A[
(fused * T.int64(3) + (_k + T.int64(1))) // T.int64(3),
(fused * T.int64(3) + (_k + T.int64(1))) % T.int64(3),
]
)
vdevices = [I.vdevice("llvm", 0)]
x = relax.Var("x", R.Tensor((2, 3), "float32", vdevices[0]))
bb = relax.BlockBuilder()

@R.function
def foo(
x: R.Tensor((2, 3), dtype="float32", vdevice="llvm")
) -> R.Tensor((2, 3), dtype="float64", vdevice="llvm"):
cls = Expected
with R.dataflow():
lv0 = R.call_tir(cls.cumsum, (x,), out_sinfo=R.Tensor((2, 3), "float64", "llvm"))
gv = R.call_tir(
cls.cumprod,
(lv0,),
out_sinfo=R.Tensor((2, 3), dtype="float64", vdevice="llvm"),
)
R.output(gv)
return gv
with bb.function("foo", (x,), {"global_symbol": "foo"}):
with bb.dataflow():
lv0 = bb.emit_te(topi.cumsum, x, axis=1, dtype="float64", exclusive=False)
out = bb.emit_te(topi.cumprod, lv0, axis=1, dtype="float64", exclusive=False)
out = bb.emit_output(out)
bb.emit_func_output(out)
expected_mod = bb.finalize()
expected_mod.update_global_info("vdevice", vdevices)

mod = DispatchSortScan()(Before)
assert_structural_equal(mod, Expected)
assert_structural_equal(mod, expected_mod)


def test_dispatch_scanop_cuda():
Expand Down Expand Up @@ -172,60 +115,26 @@ class Before:
def foo(x: R.Tensor(("m", 3), "float32", "llvm")):
m = T.int64()
with R.dataflow():
gv = R.sort(x, axis=1, descending=False)
lv = R.sort(x, axis=1, descending=False)
gv = lv
R.output(gv)
return gv

@I.ir_module
class Expected:
I.module_global_infos({"vdevice": [I.vdevice("llvm", 0)]})

@T.prim_func(private=True)
def sort(var_A: T.handle, var_sort_cpu: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
m = T.int64()
data_buf = T.match_buffer(var_A, (m, T.int64(3)), align=8)
out_buf = T.match_buffer(var_sort_cpu, (m, T.int64(3)), align=8)
with T.block("sort_cpu"):
T.reads()
T.writes()
T.call_packed(
"tvm.contrib.sort.sort",
T.tvm_stack_make_array(
data_buf.data,
T.tvm_stack_make_shape(m, T.int64(3)),
0,
2,
T.float32(0),
T.int64(0),
),
T.tvm_stack_make_array(
out_buf.data,
T.tvm_stack_make_shape(m, T.int64(3)),
0,
2,
T.float32(0),
T.int64(0),
),
1,
T.bool(True),
)
vdevices = [I.vdevice("llvm", 0)]
m = tir.Var("m", "int64")
x = relax.Var("x", R.Tensor((m, 3), "float32", vdevices[0]))
bb = relax.BlockBuilder()

@R.function
def foo(
x: R.Tensor(("m", 3), dtype="float32", vdevice="llvm")
) -> R.Tensor(("m", 3), dtype="float32", vdevice="llvm"):
m = T.int64()
cls = Expected
with R.dataflow():
gv = R.call_tir(
cls.sort, (x,), out_sinfo=R.Tensor((m, 3), dtype="float32", vdevice="llvm")
)
R.output(gv)
return gv
with bb.function("foo", (x,), {"global_symbol": "foo"}):
with bb.dataflow():
out = bb.emit_te(topi.sort, x, axis=1, is_ascend=True)
out = bb.emit_output(out)
bb.emit_func_output(out)
expected_mod = bb.finalize()
expected_mod.update_global_info("vdevice", vdevices)

mod = DispatchSortScan()(Before)
assert_structural_equal(mod, Expected)
assert_structural_equal(mod, expected_mod)


def test_dispatch_sort_cuda():
Expand Down Expand Up @@ -295,55 +204,26 @@ class Before:
def foo(x: R.Tensor(("m", 3), "float32", "llvm")):
m = T.int64()
with R.dataflow():
gv = R.argsort(x, axis=1, descending=False)
lv = R.argsort(x, axis=1, descending=False, dtype="int32")
gv = lv
R.output(gv)
return gv

@I.ir_module
class Expected:
I.module_global_infos({"vdevice": [I.vdevice("llvm", 0)]})

@T.prim_func(private=True)
def argsort(var_A: T.handle, var_argsort_cpu: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
m = T.int64()
data_buf = T.match_buffer(var_A, (m, T.int64(3)), align=8)
out_buf = T.match_buffer(var_argsort_cpu, (m, T.int64(3)), "int32", align=8)
with T.block("argsort_cpu"):
T.reads(data_buf[T.int64(0) : m, T.int64(0) : T.int64(3)])
T.writes(out_buf[T.int64(0) : m, T.int64(0) : T.int64(3)])
T.call_packed(
"tvm.contrib.sort.argsort",
T.tvm_stack_make_array(
data_buf.data,
T.tvm_stack_make_shape(m, T.int64(3)),
0,
2,
T.float32(0),
T.int64(0),
),
T.tvm_stack_make_array(
out_buf.data, T.tvm_stack_make_shape(m, T.int64(3)), 0, 2, 0, T.int64(0)
),
1,
T.bool(True),
)
vdevices = [I.vdevice("llvm", 0)]
m = tir.Var("m", "int64")
x = relax.Var("x", R.Tensor((m, 3), "float32", vdevices[0]))
bb = relax.BlockBuilder()

@R.function
def foo(
x: R.Tensor(("m", 3), dtype="float32", vdevice="llvm")
) -> R.Tensor(("m", 3), dtype="int32", vdevice="llvm"):
m = T.int64()
cls = Expected
with R.dataflow():
gv = R.call_tir(
cls.argsort, (x,), out_sinfo=R.Tensor((m, 3), dtype="int32", vdevice="llvm")
)
R.output(gv)
return gv
with bb.function("foo", (x,), {"global_symbol": "foo"}):
with bb.dataflow():
out = bb.emit_te(topi.argsort, x, axis=1, is_ascend=True, dtype="int32")
out = bb.emit_output(out)
bb.emit_func_output(out)
expected_mod = bb.finalize()
expected_mod.update_global_info("vdevice", vdevices)

mod = DispatchSortScan()(Before)
assert_structural_equal(mod, Expected)
assert_structural_equal(mod, expected_mod)


def test_dispatch_argsort_cuda():
Expand Down Expand Up @@ -410,71 +290,26 @@ class Before:
def foo(x: R.Tensor(("m", 3), "float32", "llvm")):
m = T.int64()
with R.dataflow():
gv = R.topk(x, k=2, axis=1, largest=True)
lv = R.topk(x, k=2, axis=1, largest=True)
gv = lv
R.output(gv)
return gv

@I.ir_module
class Expected:
I.module_global_infos({"vdevice": [I.vdevice("llvm", 0)]})

@T.prim_func(private=True)
def topk(var_A: T.handle, var_topk_cpu_v0: T.handle, var_topk_cpu_v1: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
m = T.int64()
data_buf = T.match_buffer(var_A, (m, T.int64(3)), align=8)
value_buf = T.match_buffer(var_topk_cpu_v0, (m, T.int64(1)), align=8)
indices_buf = T.match_buffer(var_topk_cpu_v1, (m, T.int64(1)), "int32", align=8)
with T.block("topk_cpu"):
T.reads(data_buf[T.int64(0) : m, T.int64(0) : T.int64(3)])
T.writes(
value_buf[T.int64(0) : m, T.int64(0)], indices_buf[T.int64(0) : m, T.int64(0)]
)
T.call_packed(
"tvm.contrib.sort.topk",
T.tvm_stack_make_array(
data_buf.data,
T.tvm_stack_make_shape(m, T.int64(3)),
0,
2,
T.float32(0),
T.int64(0),
),
T.tvm_stack_make_array(
value_buf.data, T.tvm_stack_make_shape(m, 1), 0, 2, T.float32(0), T.int64(0)
),
T.tvm_stack_make_array(
indices_buf.data, T.tvm_stack_make_shape(m, 1), 0, 2, 0, T.int64(0)
),
1,
1,
"both",
T.bool(False),
)
vdevices = [I.vdevice("llvm", 0)]
m = tir.Var("m", "int64")
x = relax.Var("x", R.Tensor((m, 3), "float32", vdevices[0]))
bb = relax.BlockBuilder()

@R.function
def foo(
x: R.Tensor(("m", 3), dtype="float32", vdevice="llvm")
) -> R.Tuple(
R.Tensor(("m", 1), dtype="float32", vdevice="llvm"),
R.Tensor(("m", 1), dtype="int32", vdevice="llvm"),
):
m = T.int64()
cls = Expected
with R.dataflow():
gv = R.call_tir(
cls.topk,
(x,),
out_sinfo=[
R.Tensor((m, 1), dtype="float32", vdevice="llvm"),
R.Tensor((m, 1), dtype="int32", vdevice="llvm"),
],
)
R.output(gv)
return gv
with bb.function("foo", (x,), {"global_symbol": "foo"}):
with bb.dataflow():
out = bb.emit_te(topi.topk, x, k=2, axis=1, is_ascend=False, dtype="int32")
out = bb.emit_output(out)
bb.emit_func_output(out)
expected_mod = bb.finalize()
expected_mod.update_global_info("vdevice", vdevices)

mod = DispatchSortScan()(Before)
assert_structural_equal(mod, Expected)
assert_structural_equal(mod, expected_mod)


def test_dispatch_topk_cuda():
Expand All @@ -494,12 +329,11 @@ def foo(x: R.Tensor((2, 3), "float32", "cuda")):

vdevices = [I.vdevice("cuda", 0)]
x = relax.Var("x", R.Tensor((2, 3), "float32", vdevices[0]))
y = relax.Var("y", R.Tensor((2, 3), "float32"))
bb = relax.BlockBuilder()
with target:
with bb.function("foo", (x,), {"global_symbol": "foo"}):
with bb.dataflow():
out = bb.emit_te(topi.cuda.topk, x, axis=1, is_ascend=False, dtype="int32")
out = bb.emit_te(topi.cuda.topk, x, k=2, axis=1, is_ascend=False, dtype="int32")
out = bb.emit_output(out)
bb.emit_func_output(out)
expected_mod = bb.finalize()
Expand Down