From b73aaaeaaf76fe50f6012294b97b088f16cbd621 Mon Sep 17 00:00:00 2001
From: Navya Mehta
Date: Mon, 27 Nov 2023 14:17:59 -0500
Subject: [PATCH 01/10] Preserve Pytorch Span Names

---
 python/tvm/relay/frontend/pytorch.py | 83 ++++++++++++++++++++++------
 1 file changed, 66 insertions(+), 17 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 9374a2491280..660623b83a44 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -21,6 +21,7 @@
 """PT: PyTorch frontend."""
 import functools
 import itertools
+from typing import ABC
 import math
 import re
 import sys
@@ -137,7 +138,7 @@ def _is_int_seq(seq):
 class PyTorchOpConverter:
     """A helper class for holding PyTorch op converters."""
 
-    def __init__(self, prelude, default_dtype, use_parser_friendly_name=False):
+    def __init__(self, prelude, default_dtype, use_parser_friendly_name=False, preserve_pytorch_scopes=False):
         self.prelude = prelude
         self.default_dtype = default_dtype
         self.create_convert_map()
@@ -146,6 +147,7 @@ def __init__(self, prelude, default_dtype, use_parser_friendly_name=False):
         self.op_type_dict = {}  # map from op type to its presenting order
         self.current_op = []  # stack for recording current processing op
         self.use_parser_friendly_name = use_parser_friendly_name
+        self.preserve_pytorch_scopes = preserve_pytorch_scopes
 
         # this incrementally infers the type, see the comments on the type visitor
         # above.
@@ -4204,7 +4206,7 @@ def report_missing_conversion(self, op_names):
     def convert_block(self, block, outputs):
         """Translate Torch "Block", used for prim::If and prim::Loop"""
         ops = _get_operator_nodes(
-            block.nodes(), self.source_map, self.op_type_dict, self.use_parser_friendly_name
+            block.nodes(), self.source_map, self.op_type_dict, self.use_parser_friendly_name, self.preserve_pytorch_scopes
         )
         ret_names = _get_input_names(block.returnNode())
         return self.convert_operators(ops, outputs, ret_names)
@@ -4771,25 +4773,67 @@ def _get_constant(node):
     return None
 
 
-def _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name):
-    """Rewrite debug name of node outputs with its operator type"""
-
-    def _get_source_name(op_type):
+class NodeNamer(ABC):
+    def __init__(self, op_type_dict):
+        self._op_type_dict = op_type_dict
+
+    def increment_op_type_idx(node):
+        op_type = node.kind()
         op_idx = 0
         if op_type in op_type_dict:
             op_idx = op_type_dict[op_type] + 1
         op_type_dict[op_type] = op_idx
-        return "_".join([op_type, str(op_idx)])
+        return op_idx
+
+    def get_node_source_name(self, node):
+        raise NotImplementedError()
 
-    # get source name of operator and rename all of its outputs
+    def get_node_output_name(self, node, node_source_name, index):
+        raise NotImplementedError()
+
+
+class DefaultNodeKindNamer(NodeNamer):
+    """
     # e.g. node.kind(): aten::adaptive_max_pool2d
     # node_src_name -> aten::adaptive_max_pool2d_x
     # output_1 -> aten::adaptive_max_pool2d_x_0
     # output_2 -> aten::adaptive_max_pool2d_x_1
+    """
+    def get_node_source_name(self, node):
+        op_idx = self.increment_op_type_idx(node)
+        return "_".join([op_type, str(op_idx)])
+
+    def get_node_output_name(self, node, node_src_name, index):
+        return "_".join([node_src_name, str(index)])
+
+
+class PytorchScopePreservingNamer(NodeNamer):
+    MODULE_PREFIX = "__module."
+
+    def get_node_source_name(self, node):
+        node_src_name = node.scopeName().split("/")[-1]
+        if node_src_name.startswith(self.MODULE_PREFIX):
+            node_src_name = node_src_name[len(self.MODULE_PREFIX):]
+        return node_src_name
+
+    def get_node_output_name(self, node, node_src_name, index):
+        op_idx = self.increment_op_type_idx(node)
+        return "_".join([node_src_name, str(op_idx), str(index)])
+
+
+def _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes):
+    """Rewrite debug name of node outputs with its operator type"""
+    namer = (
+        PytorchScopePreservingNamer(op_type_dict)
+        if preserve_pytorch_scopes
+        else DefaultNodeKindNamer(op_type_dict)
+    )
+    # get source name of operator and rename all of its outputs
     if node.kind() != "prim::GetAttr":
-        node_src_name = _get_source_name(node.kind())
+        node_src_name = namer.get_node_source_name(node)
         for index, output in enumerate(node.outputs()):
-            output.setDebugName("_".join([node_src_name, str(index)]))
+            name = node.get_node_output_name(node, node_src_name, index)
+            output.setDebugName(name)
         # update source map
         # if use_parser_friendly_name is True: e.g. prim::Constant_0 -> prim__Constant_0
         if use_parser_friendly_name:
@@ -4797,7 +4841,7 @@ def _get_source_name(op_type):
         source_map[node] = node_src_name
 
 
-def _debug_rename(graph, use_parser_friendly_name):
+def _debug_rename(graph, use_parser_friendly_name, preserve_pytorch_scopes):
     """Returns map between node and source name"""
     source_map, op_type_dict = {}, {}
     prim_with_blocks = ["prim::If", "prim::Loop"]
@@ -4809,13 +4853,13 @@ def _traverse_graph(nodes):
             if node.kind() in prim_with_blocks:
                 for block in node.blocks():
                     _traverse_graph(block.nodes())
-            _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name)
+            _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes)
 
     _traverse_graph(graph.nodes())
     return source_map
 
 
-def _get_operator_nodes(nodes, source_map=None, op_type_dict=None, use_parser_friendly_name=False):
+def _get_operator_nodes(nodes, source_map=None, op_type_dict=None, use_parser_friendly_name=False, preserve_pytorch_scopes=False):
     """Returns torch IR nodes that need conversion to Relay"""
     ops, should_rename_graph = [], all([source_map, op_type_dict]) is not None
 
@@ -4825,7 +4869,7 @@ def _get_operator_nodes(nodes, source_map=None, op_type_dict=None, use_parser_fr
             continue
 
         if should_rename_graph:
-            _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name)
+            _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes)
 
         if node.outputsSize() > 1:
             node_name = "_".join(_get_output_names(node))
@@ -5080,6 +5124,7 @@ def from_pytorch(
     use_parser_friendly_name=False,
     keep_quantized_weight=False,
     export_renamed_c_graph_path=None,
+    preserve_pytorch_scopes=False
 ):
     """Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
     The companion parameters will be handled automatically.
@@ -5126,6 +5171,10 @@ def from_pytorch(
         Export the renamed torch._C.Graph to the path.
         During the conversion, variable names in torch._C.Graph will be assigned based on their op
         types. The exported text file can be the reference to spans.
+
+    preserve_pytorch_scopes : bool
+        When naming the different nodes in the TVM graph, use the "scope name" from the Pytorch graph.
+        If false, a default namer is used that does not preserve the Pytorch scope names.
 
     Returns
     -------
@@ -5141,7 +5190,7 @@ def from_pytorch(
     prelude = Prelude(mod)
     enable_lower_all_tuples = True
 
-    converter = PyTorchOpConverter(prelude, default_dtype, use_parser_friendly_name)
+    converter = PyTorchOpConverter(prelude, default_dtype, use_parser_friendly_name, preserve_pytorch_scopes)
 
     graph = script_module.graph.copy()
 
@@ -5173,7 +5222,7 @@ def from_pytorch(
 
     # rename _C.Graph here for constructing meaningful source name of graph nodes
    # by doing so, we could Use source_map as the reference to rename model parameters
-    source_map = _debug_rename(graph, use_parser_friendly_name)
+    source_map = _debug_rename(graph, use_parser_friendly_name, preserve_pytorch_scopes)
     param_vars, tensors, packed_param_map, param_debug_name_map = convert_params(
         graph, params, source_map, use_parser_friendly_name
     )
@@ -5201,7 +5250,7 @@ def from_pytorch(
         converter.update_convert_map(qnn_torch.convert_map)
 
     operator_nodes = _get_operator_nodes(
-        graph.nodes(), converter.source_map, converter.op_type_dict, use_parser_friendly_name
+        graph.nodes(), converter.source_map, converter.op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes
     )
     ret_name = _get_input_names(graph.return_node())
     outputs = converter.convert_operators(operator_nodes, outputs, ret_name)

From 7ea2233a0c1987992dd3c7d03e05d408744e421c Mon Sep 17 00:00:00 2001
From: Navya Mehta <134946169+navya-encharge@users.noreply.github.com>
Date: Mon, 27 Nov 2023 14:19:16 -0500
Subject: [PATCH 02/10] Update pytorch.py

---
 python/tvm/relay/frontend/pytorch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 660623b83a44..8b59e6647606 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -21,7 +21,7 @@
 """PT: PyTorch frontend."""
 import functools
 import itertools
-from typing import ABC
+from abc import ABC
 import math
 import re
 import sys

From ed52226e7bb6a336ab9b5bb791d32b811c9a6593 Mon Sep 17 00:00:00 2001
From: Navya Mehta
Date: Mon, 27 Nov 2023 15:46:03 -0500
Subject: [PATCH 03/10] Add tests

---
 .../frontend/pytorch/test_span_naming.py      | 63 +++++++++++++++++++
 1 file changed, 63 insertions(+)
 create mode 100644 tests/python/frontend/pytorch/test_span_naming.py

diff --git a/tests/python/frontend/pytorch/test_span_naming.py b/tests/python/frontend/pytorch/test_span_naming.py
new file mode 100644
index 000000000000..65c18b2cdfd1
--- /dev/null
+++ b/tests/python/frontend/pytorch/test_span_naming.py
@@ -0,0 +1,63 @@
+import torch.nn as nn
+import torch
+import pytest
+import tvm
+
+
+class NestedConvModule(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        x = self.relu(self.conv(x))
+        return x
+
+
+class SimpleTwoConvModule(nn.Module):
+    def __init__(self):
+        super().__init__()
+        # First convolutional module
+        self.image_block1 = NestedConvModule(in_channels=3, out_channels=64)
+        # Second convolutional module
+        self.image_block2 = NestedConvModule(in_channels=64, out_channels=64)
+
+    def forward(self, x):
+        # Forward pass through the first convolutional module
+        x1 = self.image_block1(x)
+        # Forward pass through the second convolutional module
+        x2 = self.image_block2(x1)
+        # Add the outputs of the two convolutional modules
+        output = x1 + x2
+        return output
+
+
+@pytest.fixture
+def traced_model_and_inputs():
+    model = SimpleTwoConvModule()
+    sample_input = torch.zeros((1, 3, 64, 64), dtype=torch.float32)
+    with torch.no_grad():
+        traced_torch_model = torch.jit.trace(model, sample_input)
+    import_input = ["model_input", (1, 3, 64, 64)]
+    return traced_torch_model, import_input
+
+
+def test_default_span_names(traced_model_and_inputs):
+    traced_torch_model, import_input = traced_model_and_inputs
+    relay_model_ir, relay_model_params = tvm.relay.frontend.from_pytorch(
+        traced_torch_model, import_input
+    )
+    # By default, we assign new names based on the op kind
+    import pdb
+    pdb.set_trace()
+
+
+def test_pytorch_scope_based_span_names(traced_model_and_inputs):
+    traced_torch_model, import_input = traced_model_and_inputs
+    relay_model_ir, relay_model_params = tvm.relay.frontend.from_pytorch(
+        traced_torch_model, import_input, preserve_pytorch_scopes=True
+    )
+    # If specified, we are preserving the pytorch named
+    import pdb
+    pdb.set_trace()

From d7a26575a348e223bbfd5a6d8f853fb537f79fbb Mon Sep 17 00:00:00 2001
From: Navya Mehta
Date: Mon, 27 Nov 2023 12:56:56 -0800
Subject: [PATCH 04/10] WIP

---
 python/tvm/relay/frontend/pytorch.py          | 18 +++++++++++-------
 .../frontend/pytorch/test_span_naming.py      |  2 +-
 2 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 8b59e6647606..af49c725b470 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -4674,7 +4674,11 @@ def _get_input_names(node_or_graph):
 
 
 def _get_op_inputs(op_node, outputs):
-    return [outputs[name] for name in _get_input_names(op_node)]
+    try:
+        return [outputs[name] for name in _get_input_names(op_node)]
+    except:
+        import pdb
+        pdb.set_trace()
 
 
 def _get_node_type(node):
@@ -4778,12 +4781,12 @@ class NodeNamer(ABC):
     def __init__(self, op_type_dict):
         self._op_type_dict = op_type_dict
 
-    def increment_op_type_idx(node):
-        op_type = node.kind()
+    def increment_op_type_idx(self, node):
+        op_type = node.kind()
         op_idx = 0
-        if op_type in op_type_dict:
-            op_idx = op_type_dict[op_type] + 1
-        op_type_dict[op_type] = op_idx
+        if op_type in self._op_type_dict:
+            op_idx = self._op_type_dict[op_type] + 1
+        self._op_type_dict[op_type] = op_idx
         return op_idx
 
     def get_node_source_name(self, node):
@@ -4801,7 +4805,7 @@ class DefaultNodeKindNamer(NodeNamer):
     """
     def get_node_source_name(self, node):
         op_idx = self.increment_op_type_idx(node)
-        return "_".join([op_type, str(op_idx)])
+        return "_".join([node.kind(), str(op_idx)])
 
     def get_node_output_name(self, node, node_src_name, index):
         return "_".join([node_src_name, str(index)])
@@ -4832,7 +4836,7 @@ def _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name, pr
     if node.kind() != "prim::GetAttr":
         node_src_name = namer.get_node_source_name(node)
         for index, output in enumerate(node.outputs()):
-            name = node.get_node_output_name(node, node_src_name, index)
+            name = namer.get_node_output_name(node, node_src_name, index)
             output.setDebugName(name)
         # update source map
         # if use_parser_friendly_name is True: e.g. prim::Constant_0 -> prim__Constant_0
diff --git a/tests/python/frontend/pytorch/test_span_naming.py b/tests/python/frontend/pytorch/test_span_naming.py
index 65c18b2cdfd1..82b47c050dcc 100644
--- a/tests/python/frontend/pytorch/test_span_naming.py
+++ b/tests/python/frontend/pytorch/test_span_naming.py
@@ -39,7 +39,7 @@ def traced_model_and_inputs():
     sample_input = torch.zeros((1, 3, 64, 64), dtype=torch.float32)
     with torch.no_grad():
         traced_torch_model = torch.jit.trace(model, sample_input)
-    import_input = ["model_input", (1, 3, 64, 64)]
+    import_input = [("model_input", (1, 3, 64, 64))]
     return traced_torch_model, import_input

From 1935e964a97ef59028ca4445718e41fac62de2d6 Mon Sep 17 00:00:00 2001
From: Navya Mehta
Date: Mon, 27 Nov 2023 16:49:56 -0500
Subject: [PATCH 05/10] Changes and tests

---
 python/tvm/relay/frontend/pytorch.py          | 78 ++++++++++++------
 .../frontend/pytorch/test_span_naming.py      | 59 ++++++++------
 2 files changed, 85 insertions(+), 52 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index af49c725b470..8ae586066cb1 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -138,7 +138,9 @@ def _is_int_seq(seq):
 class PyTorchOpConverter:
     """A helper class for holding PyTorch op converters."""
 
-    def __init__(self, prelude, default_dtype, use_parser_friendly_name=False, preserve_pytorch_scopes=False):
+    def __init__(
+        self, prelude, default_dtype, use_parser_friendly_name=False, preserve_pytorch_scopes=False
+    ):
         self.prelude = prelude
         self.default_dtype = default_dtype
         self.create_convert_map()
@@ -4206,7 +4208,11 @@ def report_missing_conversion(self, op_names):
     def convert_block(self, block, outputs):
         """Translate Torch "Block", used for prim::If and prim::Loop"""
         ops = _get_operator_nodes(
-            block.nodes(), self.source_map, self.op_type_dict, self.use_parser_friendly_name, self.preserve_pytorch_scopes
+            block.nodes(),
+            self.source_map,
+            self.op_type_dict,
+            self.use_parser_friendly_name,
+            self.preserve_pytorch_scopes,
         )
         ret_names = _get_input_names(block.returnNode())
         return self.convert_operators(ops, outputs, ret_names)
@@ -4674,11 +4680,7 @@ def _get_input_names(node_or_graph):
 
 
 def _get_op_inputs(op_node, outputs):
-    try:
-        return [outputs[name] for name in _get_input_names(op_node)]
-    except:
-        import pdb
-        pdb.set_trace()
+    return [outputs[name] for name in _get_input_names(op_node)]
 
 
 def _get_node_type(node):
@@ -4778,17 +4780,16 @@ def _get_constant(node):
 
 
 class NodeNamer(ABC):
-    def __init__(self, op_type_dict):
-        self._op_type_dict = op_type_dict
-
-    def increment_op_type_idx(self, node):
-        op_type = node.kind()
+    def __init__(self, op_counter_dict):
+        self.op_counter_dict = op_counter_dict
+
+    def increment_counter(self, identifier):
         op_idx = 0
-        if op_type in self._op_type_dict:
-            op_idx = self._op_type_dict[op_type] + 1
-        self._op_type_dict[op_type] = op_idx
+        if identifier in self.op_counter_dict:
+            op_idx = self.op_counter_dict[identifier] + 1
+        self.op_counter_dict[identifier] = op_idx
         return op_idx
-
+
     def get_node_source_name(self, node):
         raise NotImplementedError()
 
@@ -4803,8 +4804,9 @@ class DefaultNodeKindNamer(NodeNamer):
     # output_1 -> aten::adaptive_max_pool2d_x_0
     # output_2 -> aten::adaptive_max_pool2d_x_1
     """
+
     def get_node_source_name(self, node):
-        op_idx = self.increment_op_type_idx(node)
+        op_idx = self.increment_counter(node.kind())
         return "_".join([node.kind(), str(op_idx)])
 
     def get_node_output_name(self, node, node_src_name, index):
         return "_".join([node_src_name, str(index)])
@@ -4817,19 +4819,21 @@ class PytorchScopePreservingNamer(NodeNamer):
     def get_node_source_name(self, node):
         node_src_name = node.scopeName().split("/")[-1]
         if node_src_name.startswith(self.MODULE_PREFIX):
-            node_src_name = node_src_name[len(self.MODULE_PREFIX):]
+            node_src_name = node_src_name[len(self.MODULE_PREFIX) :]
         return node_src_name
 
     def get_node_output_name(self, node, node_src_name, index):
-        op_idx = self.increment_op_type_idx(node)
+        op_idx = self.increment_counter(node_src_name)
         return "_".join([node_src_name, str(op_idx), str(index)])
 
 
-def _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes):
+def _rename_outputs(
+    node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes
+):
     """Rewrite debug name of node outputs with its operator type"""
     namer = (
-        PytorchScopePreservingNamer(op_type_dict) 
-        if preserve_pytorch_scopes 
+        PytorchScopePreservingNamer(op_type_dict)
+        if preserve_pytorch_scopes
         else DefaultNodeKindNamer(op_type_dict)
     )
     # get source name of operator and rename all of its outputs
@@ -4857,13 +4861,21 @@ def _traverse_graph(nodes):
             if node.kind() in prim_with_blocks:
                 for block in node.blocks():
                     _traverse_graph(block.nodes())
-            _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes)
+            _rename_outputs(
+                node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes
+            )
 
     _traverse_graph(graph.nodes())
     return source_map
 
 
-def _get_operator_nodes(nodes, source_map=None, op_type_dict=None, use_parser_friendly_name=False, preserve_pytorch_scopes=False):
+def _get_operator_nodes(
+    nodes,
+    source_map=None,
+    op_type_dict=None,
+    use_parser_friendly_name=False,
+    preserve_pytorch_scopes=False,
+):
     """Returns torch IR nodes that need conversion to Relay"""
     ops, should_rename_graph = [], all([source_map, op_type_dict]) is not None
 
@@ -4873,7 +4885,9 @@ def _get_operator_nodes(nodes, source_map=None, op_type_dict=None, use_parser_fr
             continue
 
         if should_rename_graph:
-            _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes)
+            _rename_outputs(
+                node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes
+            )
 
         if node.outputsSize() > 1:
             node_name = "_".join(_get_output_names(node))
@@ -5128,7 +5142,7 @@ def from_pytorch(
     use_parser_friendly_name=False,
     keep_quantized_weight=False,
     export_renamed_c_graph_path=None,
-    preserve_pytorch_scopes=False
+    preserve_pytorch_scopes=False,
 ):
     """Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
     The companion parameters will be handled automatically.
@@ -5175,7 +5189,7 @@ def from_pytorch(
         Export the renamed torch._C.Graph to the path.
         During the conversion, variable names in torch._C.Graph will be assigned based on their op
         types. The exported text file can be the reference to spans.
-    
+
     preserve_pytorch_scopes : bool
         When naming the different nodes in the TVM graph, use the "scope name" from the Pytorch graph.
         If false, a default namer is used that does not preserve the Pytorch scope names.
@@ -5194,7 +5208,9 @@ def from_pytorch(
     prelude = Prelude(mod)
     enable_lower_all_tuples = True
 
-    converter = PyTorchOpConverter(prelude, default_dtype, use_parser_friendly_name, preserve_pytorch_scopes)
+    converter = PyTorchOpConverter(
+        prelude, default_dtype, use_parser_friendly_name, preserve_pytorch_scopes
+    )
 
     graph = script_module.graph.copy()
 
@@ -5254,7 +5270,11 @@ def from_pytorch(
         converter.update_convert_map(qnn_torch.convert_map)
 
     operator_nodes = _get_operator_nodes(
-        graph.nodes(), converter.source_map, converter.op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes
+        graph.nodes(),
+        converter.source_map,
+        converter.op_type_dict,
+        use_parser_friendly_name,
+        preserve_pytorch_scopes,
     )
     ret_name = _get_input_names(graph.return_node())
     outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
diff --git a/tests/python/frontend/pytorch/test_span_naming.py b/tests/python/frontend/pytorch/test_span_naming.py
index 82b47c050dcc..0fd3ef6a1d5c 100644
--- a/tests/python/frontend/pytorch/test_span_naming.py
+++ b/tests/python/frontend/pytorch/test_span_naming.py
@@ -1,6 +1,5 @@
 import torch.nn as nn
 import torch
-import pytest
 import tvm
 
 
@@ -15,6 +14,11 @@ def forward(self, x):
         return x
 
 
+class NestedFinalModule(nn.Module):
+    def forward(self, x, y):
+        return x + y
+
+
 class SimpleTwoConvModule(nn.Module):
     def __init__(self):
         super().__init__()
@@ -22,6 +26,7 @@ def __init__(self):
         self.image_block1 = NestedConvModule(in_channels=3, out_channels=64)
         # Second convolutional module
         self.image_block2 = NestedConvModule(in_channels=64, out_channels=64)
+        self.final_block = NestedFinalModule()
 
     def forward(self, x):
         # Forward pass through the first convolutional module
@@ -29,35 +34,43 @@ def forward(self, x):
         # Forward pass through the second convolutional module
         x2 = self.image_block2(x1)
         # Add the outputs of the two convolutional modules
-        output = x1 + x2
-        return output
+        return self.final_block(x1, x2)
 
 
-@pytest.fixture
-def traced_model_and_inputs():
+def test_pytorch_scope_based_span_names():
     model = SimpleTwoConvModule()
     sample_input = torch.zeros((1, 3, 64, 64), dtype=torch.float32)
     with torch.no_grad():
         traced_torch_model = torch.jit.trace(model, sample_input)
     import_input = [("model_input", (1, 3, 64, 64))]
-    return traced_torch_model, import_input
-
-
-def test_default_span_names(traced_model_and_inputs):
-    traced_torch_model, import_input = traced_model_and_inputs
-    relay_model_ir, relay_model_params = tvm.relay.frontend.from_pytorch(
-        traced_torch_model, import_input
-    )
-    # By default, we assign new names based on the op kind
-    import pdb
-    pdb.set_trace()
-
-
-def test_pytorch_scope_based_span_names(traced_model_and_inputs):
-    traced_torch_model, import_input = traced_model_and_inputs
     relay_model_ir, relay_model_params = tvm.relay.frontend.from_pytorch(
         traced_torch_model, import_input, preserve_pytorch_scopes=True
     )
-    # If specified, we are preserving the pytorch named
-    import pdb
-    pdb.set_trace()
+    # If specified, we are preserving the pytorch named spans
+    for block in [1, 2]:
+        for key in ["weight", "bias"]:
+            assert f"image_block{block}.conv.{key}" in relay_model_params.keys()
+    # Manually check all span names since asserting structural equality is not sufficient
+    current_call = relay_model_ir["main"].body
+    assert current_call.op.name == "add"
+    assert current_call.span is not None and current_call.span.source_name.name == "final_block"
+    current_call = current_call.args[1]
+    for block in [2, 1]:
+        assert current_call.op.name == "nn.relu"
+        assert (
+            current_call.span is not None
+            and current_call.span.source_name.name == f"image_block{block}.relu"
+        )
+        current_call = current_call.args[0]
+        assert current_call.op.name == "nn.bias_add"
+        assert (
+            current_call.span is not None
+            and current_call.span.source_name.name == f"image_block{block}.conv"
+        )
+        current_call = current_call.args[0]
+        assert current_call.op.name == "nn.conv2d"
+        assert (
+            current_call.span is not None
+            and current_call.span.source_name.name == f"image_block{block}.conv"
+        )
+        current_call = current_call.args[0]

From c7a2d61bb3d24fb5f4d7f5dce5df81dc0883c047 Mon Sep 17 00:00:00 2001
From: Navya Mehta
Date: Wed, 29 Nov 2023 18:14:11 -0500
Subject: [PATCH 06/10] Michael Klaiber feedback

---
 python/tvm/relay/frontend/pytorch.py | 29 ++++++++++++++++++----------
 1 file changed, 19 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 8ae586066cb1..f38bae4291e1 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -22,6 +22,7 @@
 import functools
 import itertools
 from abc import ABC
+from typing import Dict
 import math
 import re
 import sys
@@ -4780,49 +4781,57 @@ def _get_constant(node):
 
 
 class NodeNamer(ABC):
-    def __init__(self, op_counter_dict):
+    """Name each node and output edge in the relay graph"""
+
+    def __init__(self, op_counter_dict: Dict[str, int]):
         self.op_counter_dict = op_counter_dict
 
-    def increment_counter(self, identifier):
+    def increment_counter(self, identifier: str) -> int:
         op_idx = 0
         if identifier in self.op_counter_dict:
             op_idx = self.op_counter_dict[identifier] + 1
         self.op_counter_dict[identifier] = op_idx
         return op_idx
 
-    def get_node_source_name(self, node):
+    def get_node_source_name(self, node) -> str:
         raise NotImplementedError()
 
-    def get_node_output_name(self, node, node_source_name, index):
+    def get_node_output_name(self, node_source_name: str, index: int) -> str:
         raise NotImplementedError()
 
 
 class DefaultNodeKindNamer(NodeNamer):
     """
+    Namer that uses a default naming based on the "type"/kind of node
     # e.g. node.kind(): aten::adaptive_max_pool2d
     # node_src_name -> aten::adaptive_max_pool2d_x
     # output_1 -> aten::adaptive_max_pool2d_x_0
     # output_2 -> aten::adaptive_max_pool2d_x_1
     """
 
-    def get_node_source_name(self, node):
+    def get_node_source_name(self, node) -> str:
         op_idx = self.increment_counter(node.kind())
         return "_".join([node.kind(), str(op_idx)])
 
-    def get_node_output_name(self, node, node_src_name, index):
+    def get_node_output_name(self, node_src_name: str, index: int) -> str:
         return "_".join([node_src_name, str(index)])
 
 
 class PytorchScopePreservingNamer(NodeNamer):
+    """
+    Namer that uses the Pytorch scope to name nodes.
+    eg. node could be called "bert.encoder.layer.11.output.dense"
+    """
+
     MODULE_PREFIX = "__module."
 
-    def get_node_source_name(self, node):
+    def get_node_source_name(self, node) -> str:
         node_src_name = node.scopeName().split("/")[-1]
         if node_src_name.startswith(self.MODULE_PREFIX):
             node_src_name = node_src_name[len(self.MODULE_PREFIX) :]
         return node_src_name
 
-    def get_node_output_name(self, node, node_src_name, index):
+    def get_node_output_name(self, node_src_name: str, index: int) -> str:
         op_idx = self.increment_counter(node_src_name)
         return "_".join([node_src_name, str(op_idx), str(index)])
 
@@ -4840,7 +4849,7 @@ def _rename_outputs(
     if node.kind() != "prim::GetAttr":
         node_src_name = namer.get_node_source_name(node)
         for index, output in enumerate(node.outputs()):
-            name = namer.get_node_output_name(node, node_src_name, index)
+            name = namer.get_node_output_name(node_src_name, index)
             output.setDebugName(name)
         # update source map
         # if use_parser_friendly_name is True: e.g. prim::Constant_0 -> prim__Constant_0
@@ -5191,7 +5200,7 @@ def from_pytorch(
         types. The exported text file can be the reference to spans.
 
     preserve_pytorch_scopes : bool
-        When naming the different nodes in the TVM graph, use the "scope name" from the Pytorch graph.
+        When naming the different nodes in the Relay graph, use the "scope name" from the Pytorch graph.
         If false, a default namer is used that does not preserve the Pytorch scope names.

From db89502b1d3179b2a78dc8b20c1ab58cf2132f3d Mon Sep 17 00:00:00 2001
From: Navya Mehta
Date: Wed, 29 Nov 2023 22:04:24 -0500
Subject: [PATCH 07/10] Linting fix

---
 python/tvm/relay/frontend/pytorch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index f38bae4291e1..84a7296cf268 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -4796,7 +4796,7 @@ def increment_counter(self, identifier: str) -> int:
     def get_node_source_name(self, node) -> str:
         raise NotImplementedError()
 
-    def get_node_output_name(self, node_source_name: str, index: int) -> str:
+    def get_node_output_name(self, node_src_name: str, index: int) -> str:
         raise NotImplementedError()

From 898208d13b5f97ebd06d265c16faee3ea949ac60 Mon Sep 17 00:00:00 2001
From: Navya Mehta
Date: Wed, 29 Nov 2023 22:35:29 -0500
Subject: [PATCH 08/10] Linting Feedback Pt.2

---
 python/tvm/relay/frontend/pytorch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 84a7296cf268..403df580f761 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -5200,7 +5200,7 @@ def from_pytorch(
         types. The exported text file can be the reference to spans.
 
     preserve_pytorch_scopes : bool
-        When naming the different nodes in the Relay graph, use the "scope name" from the Pytorch graph.
+        When naming the nodes in the Relay graph, use the "scope name" from the Pytorch model.
         If false, a default namer is used that does not preserve the Pytorch scope names.
 
     Returns

From 570962fbd9135fc1270be5195e8055e9b6d7f1b4 Mon Sep 17 00:00:00 2001
From: Navya Mehta
Date: Thu, 30 Nov 2023 09:44:56 -0500
Subject: [PATCH 09/10] Test changes

---
 .../frontend/pytorch/test_span_naming.py      | 32 ++++++++++++++++++-
 1 file changed, 31 insertions(+), 1 deletion(-)

diff --git a/tests/python/frontend/pytorch/test_span_naming.py b/tests/python/frontend/pytorch/test_span_naming.py
index 0fd3ef6a1d5c..fb39ddf4f061 100644
--- a/tests/python/frontend/pytorch/test_span_naming.py
+++ b/tests/python/frontend/pytorch/test_span_naming.py
@@ -1,9 +1,32 @@
-import torch.nn as nn
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks
+# pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except
+# pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda
+# pylint: disable=missing-function-docstring, redefined-builtin, use-implicit-booleaness-not-comparison
+"""Tests to ensure span names are correctly populated when importing Pytorch"""
+from torch import nn
 import torch
 import tvm
 
 
 class NestedConvModule(nn.Module):
+    """Module that performs Conv2d and relu activation"""
+
     def __init__(self, in_channels, out_channels):
         super().__init__()
         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
@@ -15,11 +38,18 @@ def forward(self, x):
 
 
 class NestedFinalModule(nn.Module):
+    """Simple module that adds 2 inputs"""
+
     def forward(self, x, y):
         return x + y
 
 
 class SimpleTwoConvModule(nn.Module):
+    """
+    ML model that performs 2 convolutions and adds them together.
+    All operations are inside nested modules to make scope names interesting.
+    """
+
     def __init__(self):
         super().__init__()
         # First convolutional module

From 28de8f6ef86b48a0c15a3ed634046a2e6739edcf Mon Sep 17 00:00:00 2001
From: Navya Mehta
Date: Wed, 6 Dec 2023 08:18:46 -0800
Subject: [PATCH 10/10] Modify to Pytorch 2.0

---
 python/tvm/relay/frontend/pytorch.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 403df580f761..2731fee1035c 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -4823,12 +4823,11 @@ class PytorchScopePreservingNamer(NodeNamer):
     eg. node could be called "bert.encoder.layer.11.output.dense"
     """
 
-    MODULE_PREFIX = "__module."
-
     def get_node_source_name(self, node) -> str:
-        node_src_name = node.scopeName().split("/")[-1]
-        if node_src_name.startswith(self.MODULE_PREFIX):
-            node_src_name = node_src_name[len(self.MODULE_PREFIX) :]
+        # This works per the scope naming in Pytorch 2.0 and beyond.
+        scope_name_parts = node.scopeName().split("/")
+        imp_parts = [part.split("::")[-1] for part in scope_name_parts]
+        node_src_name = ".".join([part for part in imp_parts if part])
         return node_src_name
 
     def get_node_output_name(self, node_src_name: str, index: int) -> str:
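Usage sketch (illustrative, not part of the patches above): the snippet below shows how the new flag introduced by this series is meant to be exercised, mirroring the test added in PATCH 09/10. The module names (Block, Model) and the "model_input" tensor name are made up for the example; only tvm.relay.frontend.from_pytorch and its preserve_pytorch_scopes keyword come from the series itself.

import torch
from torch import nn
import tvm


class Block(nn.Module):
    """Small nested submodule so the traced graph carries a scope name like "block.conv"."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()

    def forward(self, x):
        return self.block(x)


model = Model().eval()
sample = torch.zeros((1, 3, 32, 32), dtype=torch.float32)
with torch.no_grad():
    traced = torch.jit.trace(model, sample)

input_infos = [("model_input", (1, 3, 32, 32))]

# Default behaviour: span/source names are derived from the op kind,
# e.g. names of the form aten::relu_0 (see DefaultNodeKindNamer in the patches).
mod_default, params_default = tvm.relay.frontend.from_pytorch(traced, input_infos)

# Behaviour added by this series: span/source names reuse the PyTorch scope,
# e.g. "block.conv" / "block.relu" (see PytorchScopePreservingNamer in the patches).
mod_scoped, params_scoped = tvm.relay.frontend.from_pytorch(
    traced, input_infos, preserve_pytorch_scopes=True
)
print(mod_scoped["main"])

Printing the scoped module should show spans whose source names follow the nn.Module attribute hierarchy rather than the operator kind, which is what the tests in this series assert.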