From f00b9e7b9819ea688ff9fd63652704af62a33a6a Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Thu, 10 Nov 2022 03:21:44 -0800 Subject: [PATCH 1/3] Stop randomizing primfunc buffer order --- python/tvm/te/operation.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 5279c46aebc2..846f88d38938 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -398,11 +398,12 @@ def before_split(a: T.handle, b: T.handle) -> None: C = te.extern_primfunc([A, B], func) """ - access_map = { - k: tuple(v) for k, v in tvm.arith._ffi_api.DomainTouchedAccessMap(primfunc).items() - } - in_buffers = [buf for buf, access in access_map.items() if len(access[0])] - out_buffers = [buf for buf, access in access_map.items() if len(access[1])] + + # dt_access_map and primfunc.buffer_map are unordered, so use order from primfunc.params + dt_access_map = tvm.arith._ffi_api.DomainTouchedAccessMap(primfunc) + ordered_buffers = [primfunc.buffer_map[param] for param in primfunc.params] + in_buffers = [buf for buf in ordered_buffers if len(dt_access_map[buf][0])] + out_buffers = [buf for buf in ordered_buffers if len(dt_access_map[buf][1])] assert in_buffers, "PrimFunc has no input buffers" assert out_buffers, "PrimFunc has no output buffers" From 754b26b5b5a056ac62b23a283cd92c478d818611 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Thu, 10 Nov 2022 23:23:56 -0800 Subject: [PATCH 2/3] Add regression test --- .../unittest/test_tir_te_extern_primfunc.py | 46 ++++--------------- 1 file changed, 8 insertions(+), 38 deletions(-) diff --git a/tests/python/unittest/test_tir_te_extern_primfunc.py b/tests/python/unittest/test_tir_te_extern_primfunc.py index 26752145620a..7413e37ad7e7 100644 --- a/tests/python/unittest/test_tir_te_extern_primfunc.py +++ b/tests/python/unittest/test_tir_te_extern_primfunc.py @@ -174,11 +174,11 @@ def verify_func_4(module): class TestPrimFuncs: - func, verify = tvm.testing.parameters( - [func_1, verify_func_1], - [func_2, verify_func_2], - [func_3, verify_func_3], - [func_4, verify_func_4], + func, params, verify = tvm.testing.parameters( + [func_1, ("A"), verify_func_1], + [func_2, ("C", "D"), verify_func_2], + [func_3, ("C", "A", "D", "E"), verify_func_3], + [func_4, ("C", "A", "D", "E"), verify_func_4], ) def test_primfunc_call(self, func, verify): @@ -186,11 +186,12 @@ def test_primfunc_call(self, func, verify): func = tvm.build(func, target=target) verify(func) - def test_te_extern_call(self, func, verify): + def test_te_extern_call(self, func, params, verify): ir_mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) prim_func = ir_mod["main"] - input_tensors = create_input_tensors_for_primfunc(prim_func) + buf_name_map = {buf.name: buf for buf in func.buffer_map.values()} + input_tensors = [te.placeholder(buf_name_map[name].shape) for name in params] output = te.extern_primfunc(input_tensors, prim_func) rt_prim_func = te.create_prim_func(tensors_from_extern_op(output, prim_func)) tvm.ir.assert_structural_equal(tvm.lower(prim_func), tvm.lower(rt_prim_func)) @@ -222,36 +223,5 @@ def tensors_from_extern_op(extern, func): return ordered_tensors -def create_input_tensors_for_primfunc(primfunc): - access_map = {k: tuple(v) for k, v in _ffi_arith_api.DomainTouchedAccessMap(primfunc).items()} - in_buffers = [buf for buf, access in access_map.items() if len(access[0])] - out_buffers = [buf for buf, access in access_map.items() if len(access[1])] - assert in_buffers, "PrimFunc has no input buffers" - assert out_buffers, "PrimFunc has no output buffers" - - outputs = [] - inplace = [] - inputs = in_buffers - for obuf in out_buffers: - if obuf in in_buffers: - inplace.append(obuf) - else: - outputs.append(obuf) - - if not outputs: - iobuf = inplace.pop() - inputs.remove(iobuf) - outputs = [iobuf] - - def create_tensors(input_buffers): - tensors = [] - for buf in input_buffers: - t = te.placeholder(buf.shape, dtype=buf.dtype, name=buf.name + "_placeholder") - tensors.append(t) - return tensors - - return create_tensors(inputs) - - if __name__ == "__main__": sys.exit(pytest.main(sys.argv)) From 88c37a2056622bb583848f44bcb2dd1cc23db8f0 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Thu, 10 Nov 2022 23:30:22 -0800 Subject: [PATCH 3/3] Remove unused imports --- tests/python/unittest/test_tir_te_extern_primfunc.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/python/unittest/test_tir_te_extern_primfunc.py b/tests/python/unittest/test_tir_te_extern_primfunc.py index 7413e37ad7e7..a622f77cc737 100644 --- a/tests/python/unittest/test_tir_te_extern_primfunc.py +++ b/tests/python/unittest/test_tir_te_extern_primfunc.py @@ -21,10 +21,8 @@ import tvm import tvm.testing -from tvm import tir, te, TVMError +from tvm import te from tvm.script import tir as T -from tvm.arith import _ffi_api as _ffi_arith_api -from tvm.tir.schedule import _ffi_api as _ffi_schedule_api # TODO(csullivan): Additional tests cases needed: