diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index dcf2f08caeef..e5ad57c6b87a 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -3216,5 +3216,16 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt # ListConstruct kept original python list. Convert to tuple. ret = _expr.Tuple(ret) - mod["main"] = tvm.relay.Function(_analysis.free_vars(ret), ret) + # Separate data inputs and parameters to make sure data inputs are always in the beginning. + func_args = [] + data_inputs = [] + for arg in _analysis.free_vars(ret): + if arg.name_hint not in tvm_params.keys(): + data_inputs.append(arg) + else: + func_args.append(arg) + func_args = data_inputs + func_args + + mod["main"] = tvm.relay.Function(func_args, ret) + return transform.RemoveUnusedFunctions()(mod), tvm_params diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 54bf2fd49acb..41679bf16c5d 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -201,6 +201,8 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)] input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input])) mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) + for arg in mod["main"].params[: len(input_names)]: + assert arg.name_hint in input_names compiled_input = dict(zip(input_names, [inp.clone().cpu().numpy() for inp in baseline_input])) with tvm.transform.PassContext(opt_level=3):