From ec7505b25a2c7bbd2561476f322eae43b5eaabc0 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 4 Mar 2021 19:24:29 +0000 Subject: [PATCH 1/3] [PyTorch] Guarantee data input is the first argument --- python/tvm/relay/frontend/pytorch.py | 13 ++++++++++++- tests/python/frontend/pytorch/test_forward.py | 2 ++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index dcf2f08caeef..e5ad57c6b87a 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -3216,5 +3216,16 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt # ListConstruct kept original python list. Convert to tuple. ret = _expr.Tuple(ret) - mod["main"] = tvm.relay.Function(_analysis.free_vars(ret), ret) + # Separate data inputs and parameters to make sure data inputs are always in the beginning. + func_args = [] + data_inputs = [] + for arg in _analysis.free_vars(ret): + if arg.name_hint not in tvm_params.keys(): + data_inputs.append(arg) + else: + func_args.append(arg) + func_args = data_inputs + func_args + + mod["main"] = tvm.relay.Function(func_args, ret) + return transform.RemoveUnusedFunctions()(mod), tvm_params diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 54bf2fd49acb..0bf349081c34 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -201,6 +201,8 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)] input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input])) mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) + for idx, arg in enumerate(mod["main"].params[: len(input_names)]): + assert arg.name_hint == input_names[idx] compiled_input = dict(zip(input_names, [inp.clone().cpu().numpy() for inp in baseline_input])) with tvm.transform.PassContext(opt_level=3): From 7a71c52cf147c03ac82d01955989b3cd0cb50da8 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 5 Mar 2021 00:20:51 +0000 Subject: [PATCH 2/3] fix --- tests/python/frontend/pytorch/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 0bf349081c34..53a60d041d10 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -202,7 +202,7 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input])) mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) for idx, arg in enumerate(mod["main"].params[: len(input_names)]): - assert arg.name_hint == input_names[idx] + assert arg.name_hint in input_names[idx] compiled_input = dict(zip(input_names, [inp.clone().cpu().numpy() for inp in baseline_input])) with tvm.transform.PassContext(opt_level=3): From d6b11ce611a6f0a22f1c7e1c44281e61cf6fdbc1 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 5 Mar 2021 00:21:16 +0000 Subject: [PATCH 3/3] fix --- tests/python/frontend/pytorch/test_forward.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 53a60d041d10..41679bf16c5d 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -201,8 +201,8 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)] input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input])) mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) - for idx, arg in enumerate(mod["main"].params[: len(input_names)]): - assert arg.name_hint in input_names[idx] + for arg in mod["main"].params[: len(input_names)]: + assert arg.name_hint in input_names compiled_input = dict(zip(input_names, [inp.clone().cpu().numpy() for inp in baseline_input])) with tvm.transform.PassContext(opt_level=3):