From f1aed5348a470f20568ee215ed585a50172e28a5 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Fri, 11 Apr 2025 10:40:03 +0000 Subject: [PATCH 1/4] enhance missing func types finding in exported program and fx graph frontend --- .../relax/frontend/torch/exported_program_translator.py | 9 ++++++--- python/tvm/relax/frontend/torch/fx_translator.py | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 875ec3b83ea8..ca81a028ff29 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -511,6 +511,12 @@ def from_exported_program( ): output = None with self.block_builder.dataflow(): + + # Find all the missing function types + missing_func_types = list({node.target.__name__ for node in nodes + if node.op == "call_function" and node.target.__name__ not in self.convert_map}) + assert not missing_func_types, f"Unsupported function types {missing_func_types}" + # Translate the model. for node in nodes: if node.op == "placeholder": @@ -537,9 +543,6 @@ def from_exported_program( self.env[node] = getattr(exported_program.graph_module, node.target) elif node.op == "call_function": func_name = node.target.__name__ - assert ( - func_name in self.convert_map - ), f"Unsupported function type {func_name}" self.env[node] = self.convert_map[func_name](node) else: raise ValueError(f"Unsupported op {node.op}") diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index a5b50a7d1dce..8d7b18285957 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -884,6 +884,12 @@ def from_fx( with self.block_builder.function(name=func_name, params=inputs.copy(), attrs=func_attrs): output = None with self.block_builder.dataflow(): + + # Find all the missing function types + missing_func_types = list({node.target.__name__ for node in graph.nodes + if node.op == "call_function" and node.target.__name__ not in self.convert_map}) + assert not missing_func_types, f"Unsupported function types {missing_func_types}" + # Translate model parameters. for _, param in model.named_parameters(): shape = param.data.shape @@ -929,9 +935,6 @@ def from_fx( self.env[node] = self.convert_map[type(module)](node) elif node.op == "call_function": func_name = node.target.__name__ - assert ( - func_name in self.convert_map - ), f"Unsupported function type {func_name}" if func_name in custom_ops: self.env[node] = self.convert_map[func_name](node, self) else: From bd55d155f71720b6b3995d0b9036eea76d8684c5 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Fri, 11 Apr 2025 10:40:42 +0000 Subject: [PATCH 2/4] fix trailing space issue --- python/tvm/relax/frontend/torch/exported_program_translator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index ca81a028ff29..698c40840539 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -513,7 +513,7 @@ def from_exported_program( with self.block_builder.dataflow(): # Find all the missing function types - missing_func_types = list({node.target.__name__ for node in nodes + missing_func_types = list({node.target.__name__ for node in nodes if node.op == "call_function" and node.target.__name__ not in self.convert_map}) assert not missing_func_types, f"Unsupported function types {missing_func_types}" From d4476b542876d566cf4f100a2111d5725697c219 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Fri, 11 Apr 2025 11:04:24 +0000 Subject: [PATCH 3/4] fix lint issues by formatting the code --- .../frontend/torch/exported_program_translator.py | 10 ++++++++-- python/tvm/relax/frontend/torch/fx_translator.py | 12 +++++++++--- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 698c40840539..be17001fd034 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -513,8 +513,14 @@ def from_exported_program( with self.block_builder.dataflow(): # Find all the missing function types - missing_func_types = list({node.target.__name__ for node in nodes - if node.op == "call_function" and node.target.__name__ not in self.convert_map}) + missing_func_types = list( + { + node.target.__name__ + for node in nodes + if node.op == "call_function" + and node.target.__name__ not in self.convert_map + } + ) assert not missing_func_types, f"Unsupported function types {missing_func_types}" # Translate the model. diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 8d7b18285957..5239b6e938c2 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -885,9 +885,15 @@ def from_fx( output = None with self.block_builder.dataflow(): - # Find all the missing function types - missing_func_types = list({node.target.__name__ for node in graph.nodes - if node.op == "call_function" and node.target.__name__ not in self.convert_map}) + # Find all the missing function types + missing_func_types = list( + { + node.target.__name__ + for node in nodes + if node.op == "call_function" + and node.target.__name__ not in self.convert_map + } + ) assert not missing_func_types, f"Unsupported function types {missing_func_types}" # Translate model parameters. From 7f300b37f15a30caf91f018cad11f11b233e0433 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Sat, 12 Apr 2025 13:38:47 +0000 Subject: [PATCH 4/4] fix name error in fx frontend --- python/tvm/relax/frontend/torch/fx_translator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 5239b6e938c2..f6dd235d5a23 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -889,7 +889,7 @@ def from_fx( missing_func_types = list( { node.target.__name__ - for node in nodes + for node in graph.nodes if node.op == "call_function" and node.target.__name__ not in self.convert_map }