Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions python/tvm/te/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,11 +398,12 @@ def before_split(a: T.handle, b: T.handle) -> None:

C = te.extern_primfunc([A, B], func)
"""
access_map = {
k: tuple(v) for k, v in tvm.arith._ffi_api.DomainTouchedAccessMap(primfunc).items()
}
in_buffers = [buf for buf, access in access_map.items() if len(access[0])]
out_buffers = [buf for buf, access in access_map.items() if len(access[1])]

# dt_access_map and primfunc.buffer_map are unordered, so use order from primfunc.params
dt_access_map = tvm.arith._ffi_api.DomainTouchedAccessMap(primfunc)
ordered_buffers = [primfunc.buffer_map[param] for param in primfunc.params]
in_buffers = [buf for buf in ordered_buffers if len(dt_access_map[buf][0])]
out_buffers = [buf for buf in ordered_buffers if len(dt_access_map[buf][1])]
assert in_buffers, "PrimFunc has no input buffers"
assert out_buffers, "PrimFunc has no output buffers"

Expand Down
50 changes: 9 additions & 41 deletions tests/python/unittest/test_tir_te_extern_primfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@

import tvm
import tvm.testing
from tvm import tir, te, TVMError
from tvm import te
from tvm.script import tir as T
from tvm.arith import _ffi_api as _ffi_arith_api
from tvm.tir.schedule import _ffi_api as _ffi_schedule_api


# TODO(csullivan): Additional tests cases needed:
Expand Down Expand Up @@ -174,23 +172,24 @@ def verify_func_4(module):


class TestPrimFuncs:
func, verify = tvm.testing.parameters(
[func_1, verify_func_1],
[func_2, verify_func_2],
[func_3, verify_func_3],
[func_4, verify_func_4],
func, params, verify = tvm.testing.parameters(
[func_1, ("A"), verify_func_1],
[func_2, ("C", "D"), verify_func_2],
[func_3, ("C", "A", "D", "E"), verify_func_3],
[func_4, ("C", "A", "D", "E"), verify_func_4],
)

def test_primfunc_call(self, func, verify):
target = tvm.target.Target("llvm")
func = tvm.build(func, target=target)
verify(func)

def test_te_extern_call(self, func, verify):
def test_te_extern_call(self, func, params, verify):
ir_mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
prim_func = ir_mod["main"]

input_tensors = create_input_tensors_for_primfunc(prim_func)
buf_name_map = {buf.name: buf for buf in func.buffer_map.values()}
input_tensors = [te.placeholder(buf_name_map[name].shape) for name in params]
output = te.extern_primfunc(input_tensors, prim_func)
rt_prim_func = te.create_prim_func(tensors_from_extern_op(output, prim_func))
tvm.ir.assert_structural_equal(tvm.lower(prim_func), tvm.lower(rt_prim_func))
Expand Down Expand Up @@ -222,36 +221,5 @@ def tensors_from_extern_op(extern, func):
return ordered_tensors


def create_input_tensors_for_primfunc(primfunc):
access_map = {k: tuple(v) for k, v in _ffi_arith_api.DomainTouchedAccessMap(primfunc).items()}
in_buffers = [buf for buf, access in access_map.items() if len(access[0])]
out_buffers = [buf for buf, access in access_map.items() if len(access[1])]
assert in_buffers, "PrimFunc has no input buffers"
assert out_buffers, "PrimFunc has no output buffers"

outputs = []
inplace = []
inputs = in_buffers
for obuf in out_buffers:
if obuf in in_buffers:
inplace.append(obuf)
else:
outputs.append(obuf)

if not outputs:
iobuf = inplace.pop()
inputs.remove(iobuf)
outputs = [iobuf]

def create_tensors(input_buffers):
tensors = []
for buf in input_buffers:
t = te.placeholder(buf.shape, dtype=buf.dtype, name=buf.name + "_placeholder")
tensors.append(t)
return tensors

return create_tensors(inputs)


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))