[Relay][Frontend] Preserve Pytorch Span Names #16171
Merged
Changes from all commits (10 commits):
- b73aaae Preserve Pytorch Span Names (navya-encharge)
- 7ea2233 Update pytorch.py (navya-encharge)
- ed52226 Add tests (navya-encharge)
- d7a2657 WIP (navya-encharge)
- 1935e96 Changes and tests (navya-encharge)
- c7a2d61 Michael Klaiber feedback (navya-encharge)
- db89502 Linting fix (navya-encharge)
- 898208d Linting Feedback Pt.2 (navya-encharge)
- 570962f Test changes (navya-encharge)
- 28de8f6 Modify to Pytorch 2.0 (navya-encharge)
File: pytorch.py (the PyTorch frontend)
```diff
@@ -21,6 +21,8 @@
 """PT: PyTorch frontend."""
 import functools
 import itertools
+from abc import ABC
+from typing import Dict
 import math
 import re
 import sys
@@ -137,7 +139,9 @@ def _is_int_seq(seq):
 class PyTorchOpConverter:
     """A helper class for holding PyTorch op converters."""

-    def __init__(self, prelude, default_dtype, use_parser_friendly_name=False):
+    def __init__(
+        self, prelude, default_dtype, use_parser_friendly_name=False, preserve_pytorch_scopes=False
+    ):
         self.prelude = prelude
         self.default_dtype = default_dtype
         self.create_convert_map()
@@ -146,6 +150,7 @@ def __init__(self, prelude, default_dtype, use_parser_friendly_name=False):
         self.op_type_dict = {}  # map from op type to its presenting order
         self.current_op = []  # stack for recording current processing op
         self.use_parser_friendly_name = use_parser_friendly_name
+        self.preserve_pytorch_scopes = preserve_pytorch_scopes

     # this incrementally infers the type, see the comments on the type visitor
     # above.
@@ -4204,7 +4209,11 @@ def report_missing_conversion(self, op_names):
     def convert_block(self, block, outputs):
         """Translate Torch "Block", used for prim::If and prim::Loop"""
         ops = _get_operator_nodes(
-            block.nodes(), self.source_map, self.op_type_dict, self.use_parser_friendly_name
+            block.nodes(),
+            self.source_map,
+            self.op_type_dict,
+            self.use_parser_friendly_name,
+            self.preserve_pytorch_scopes,
         )
         ret_names = _get_input_names(block.returnNode())
         return self.convert_operators(ops, outputs, ret_names)
```
```diff
@@ -4771,33 +4780,84 @@ def _get_constant(node):
     return None


-def _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name):
-    """Rewrite debug name of node outputs with its operator type"""
+class NodeNamer(ABC):
+    """Name each node and output edge in the relay graph"""

-    def _get_source_name(op_type):
+    def __init__(self, op_counter_dict: Dict[str, int]):
+        self.op_counter_dict = op_counter_dict
+
+    def increment_counter(self, identifier: str) -> int:
         op_idx = 0
-        if op_type in op_type_dict:
-            op_idx = op_type_dict[op_type] + 1
-        op_type_dict[op_type] = op_idx
-        return "_".join([op_type, str(op_idx)])
+        if identifier in self.op_counter_dict:
+            op_idx = self.op_counter_dict[identifier] + 1
+        self.op_counter_dict[identifier] = op_idx
+        return op_idx

-    # get source name of operator and rename all of its outputs
+    def get_node_source_name(self, node) -> str:
+        raise NotImplementedError()
+
+    def get_node_output_name(self, node_src_name: str, index: int) -> str:
+        raise NotImplementedError()
+
+
+class DefaultNodeKindNamer(NodeNamer):
+    """
+    Namer that uses a default naming based on the "type"/kind of node
+    # e.g. node.kind(): aten::adaptive_max_pool2d
+    # node_src_name -> aten::adaptive_max_pool2d_x
+    # output_1 -> aten::adaptive_max_pool2d_x_0
+    # output_2 -> aten::adaptive_max_pool2d_x_1
+    """
+
+    def get_node_source_name(self, node) -> str:
+        op_idx = self.increment_counter(node.kind())
+        return "_".join([node.kind(), str(op_idx)])
+
+    def get_node_output_name(self, node_src_name: str, index: int) -> str:
+        return "_".join([node_src_name, str(index)])
+
+
+class PytorchScopePreservingNamer(NodeNamer):
+    """
+    Namer that uses the Pytorch scope to name nodes.
+    eg. node could be called "bert.encoder.layer.11.output.dense"
+    """
+
+    def get_node_source_name(self, node) -> str:
+        # This works per the scope naming in Pytorch 2.0 and beyond.
+        scope_name_parts = node.scopeName().split("/")
+        imp_parts = [part.split("::")[-1] for part in scope_name_parts]
+        node_src_name = ".".join([part for part in imp_parts if part])
+        return node_src_name
+
+    def get_node_output_name(self, node_src_name: str, index: int) -> str:
+        op_idx = self.increment_counter(node_src_name)
+        return "_".join([node_src_name, str(op_idx), str(index)])
+
+
+def _rename_outputs(
+    node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes
+):
+    """Rewrite debug name of node outputs with its operator type"""
+    namer = (
+        PytorchScopePreservingNamer(op_type_dict)
+        if preserve_pytorch_scopes
+        else DefaultNodeKindNamer(op_type_dict)
+    )
+    # get source name of operator and rename all of its outputs
     if node.kind() != "prim::GetAttr":
-        node_src_name = _get_source_name(node.kind())
+        node_src_name = namer.get_node_source_name(node)
         for index, output in enumerate(node.outputs()):
-            output.setDebugName("_".join([node_src_name, str(index)]))
+            name = namer.get_node_output_name(node_src_name, index)
+            output.setDebugName(name)
         # update source map
         # if use_parser_friendly_name is True: e.g. prim::Constant_0 -> prim__Constant_0
         if use_parser_friendly_name:
             node_src_name = re.sub(r":|\.", "_", node_src_name)
         source_map[node] = node_src_name


-def _debug_rename(graph, use_parser_friendly_name):
+def _debug_rename(graph, use_parser_friendly_name, preserve_pytorch_scopes):
     """Returns map between node and source name"""
     source_map, op_type_dict = {}, {}
     prim_with_blocks = ["prim::If", "prim::Loop"]
```
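To make the two naming schemes concrete, here is a minimal sketch (illustration only, not code from this PR) of what `PytorchScopePreservingNamer.get_node_source_name` computes from a TorchScript `node.scopeName()` string; the example scope string is an assumption modeled on the test module added further below.

```python
# Illustration of the scope-name parsing used by PytorchScopePreservingNamer.
# The example scope string is hypothetical; real values come from node.scopeName().
def scope_to_source_name(scope_name: str) -> str:
    parts = scope_name.split("/")                     # one segment per enclosing module
    names = [part.split("::")[-1] for part in parts]  # keep the attribute name, drop the class name
    return ".".join(name for name in names if name)   # drop empty top-level segments


example_scope = "SimpleTwoConvModule::/NestedConvModule::image_block1/Conv2d::conv"
assert scope_to_source_name(example_scope) == "image_block1.conv"

# For contrast, DefaultNodeKindNamer names nodes by kind plus a running counter,
# e.g. the first aten::adaptive_max_pool2d node becomes "aten::adaptive_max_pool2d_0"
# and its outputs "aten::adaptive_max_pool2d_0_0", "aten::adaptive_max_pool2d_0_1".
```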
```diff
@@ -4809,13 +4869,21 @@ def _traverse_graph(nodes):
             if node.kind() in prim_with_blocks:
                 for block in node.blocks():
                     _traverse_graph(block.nodes())
-            _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name)
+            _rename_outputs(
+                node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes
+            )

     _traverse_graph(graph.nodes())
     return source_map


-def _get_operator_nodes(nodes, source_map=None, op_type_dict=None, use_parser_friendly_name=False):
+def _get_operator_nodes(
+    nodes,
+    source_map=None,
+    op_type_dict=None,
+    use_parser_friendly_name=False,
+    preserve_pytorch_scopes=False,
+):
     """Returns torch IR nodes that need conversion to Relay"""
     ops, should_rename_graph = [], all([source_map, op_type_dict]) is not None

@@ -4825,7 +4893,9 @@ def _get_operator_nodes(nodes, source_map=None, op_type_dict=None, use_parser_fr
             continue

         if should_rename_graph:
-            _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name)
+            _rename_outputs(
+                node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes
+            )

         if node.outputsSize() > 1:
             node_name = "_".join(_get_output_names(node))
```

Reviewer comment (Member), on the new `preserve_pytorch_scopes=False` parameter of `_get_operator_nodes`: "I think it can be set to True by default, can be done as a follow up."
```diff
@@ -5080,6 +5150,7 @@ def from_pytorch(
     use_parser_friendly_name=False,
     keep_quantized_weight=False,
     export_renamed_c_graph_path=None,
+    preserve_pytorch_scopes=False,
 ):
     """Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
     The companion parameters will be handled automatically.
@@ -5127,6 +5198,10 @@ def from_pytorch(
         During the conversion, variable names in torch._C.Graph will be assigned based on their op
         types. The exported text file can be the reference to spans.

+    preserve_pytorch_scopes : bool
+        When naming the nodes in the Relay graph, use the "scope name" from the Pytorch model.
+        If false, a default namer is used that does not preserve the Pytorch scope names.
+
     Returns
     -------
     mod : tvm.IRModule
@@ -5141,7 +5216,9 @@ def from_pytorch(
     prelude = Prelude(mod)
     enable_lower_all_tuples = True

-    converter = PyTorchOpConverter(prelude, default_dtype, use_parser_friendly_name)
+    converter = PyTorchOpConverter(
+        prelude, default_dtype, use_parser_friendly_name, preserve_pytorch_scopes
+    )

     graph = script_module.graph.copy()

@@ -5173,7 +5250,7 @@ def from_pytorch(
     # rename _C.Graph here for constructing meaningful source name of graph nodes
     # by doing so, we could Use source_map as the reference to rename model parameters
-    source_map = _debug_rename(graph, use_parser_friendly_name)
+    source_map = _debug_rename(graph, use_parser_friendly_name, preserve_pytorch_scopes)
     param_vars, tensors, packed_param_map, param_debug_name_map = convert_params(
         graph, params, source_map, use_parser_friendly_name
     )

@@ -5201,7 +5278,11 @@ def from_pytorch(
     converter.update_convert_map(qnn_torch.convert_map)

     operator_nodes = _get_operator_nodes(
-        graph.nodes(), converter.source_map, converter.op_type_dict, use_parser_friendly_name
+        graph.nodes(),
+        converter.source_map,
+        converter.op_type_dict,
+        use_parser_friendly_name,
+        preserve_pytorch_scopes,
     )
     ret_name = _get_input_names(graph.return_node())
     outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
```
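For reference, a hedged usage sketch of the new flag follows; the module and input names here are illustrative, and the test file added below is the authoritative, runnable example.

```python
# Sketch only: importing a traced model with scope-preserving span names.
import torch
from torch import nn
import tvm.relay as relay


class TinyModel(nn.Module):  # hypothetical example module, not from this PR
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))


example = torch.zeros((1, 3, 64, 64))
traced = torch.jit.trace(TinyModel().eval(), example)

# preserve_pytorch_scopes=True keeps names from the PyTorch module hierarchy
# (e.g. "conv") in the Relay span source names instead of the default
# aten::<kind>_<counter> scheme; it defaults to False.
mod, params = relay.frontend.from_pytorch(
    traced, [("input0", (1, 3, 64, 64))], preserve_pytorch_scopes=True
)
```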
New test file (all 106 lines added), testing that span names are correctly populated when importing from PyTorch:

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks
# pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda
# pylint: disable=missing-function-docstring, redefined-builtin, use-implicit-booleaness-not-comparison
"""Tests to ensure span names are correctly populated when importing Pytorch"""
from torch import nn
import torch
import tvm


class NestedConvModule(nn.Module):
    """Module that performs Conv2d and relu activation"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv(x))
        return x


class NestedFinalModule(nn.Module):
    """Simple module that adds 2 inputs"""

    def forward(self, x, y):
        return x + y


class SimpleTwoConvModule(nn.Module):
    """
    ML model that performs 2 convolutions and adds them together.
    All operations are inside nested modules to make scope names interesting.
    """

    def __init__(self):
        super().__init__()
        # First convolutional module
        self.image_block1 = NestedConvModule(in_channels=3, out_channels=64)
        # Second convolutional module
        self.image_block2 = NestedConvModule(in_channels=64, out_channels=64)
        self.final_block = NestedFinalModule()

    def forward(self, x):
        # Forward pass through the first convolutional module
        x1 = self.image_block1(x)
        # Forward pass through the second convolutional module
        x2 = self.image_block2(x1)
        # Add the outputs of the two convolutional modules
        return self.final_block(x1, x2)


def test_pytorch_scope_based_span_names():
    model = SimpleTwoConvModule()
    sample_input = torch.zeros((1, 3, 64, 64), dtype=torch.float32)
    with torch.no_grad():
        traced_torch_model = torch.jit.trace(model, sample_input)
    import_input = [("model_input", (1, 3, 64, 64))]
    relay_model_ir, relay_model_params = tvm.relay.frontend.from_pytorch(
        traced_torch_model, import_input, preserve_pytorch_scopes=True
    )
    # If specified, we are preserving the pytorch named spans
    for block in [1, 2]:
        for key in ["weight", "bias"]:
            assert f"image_block{block}.conv.{key}" in relay_model_params.keys()
    # Manually check all span names since asserting structural equality is not sufficient
    current_call = relay_model_ir["main"].body
    assert current_call.op.name == "add"
    assert current_call.span is not None and current_call.span.source_name.name == "final_block"
    current_call = current_call.args[1]
    for block in [2, 1]:
        assert current_call.op.name == "nn.relu"
        assert (
            current_call.span is not None
            and current_call.span.source_name.name == f"image_block{block}.relu"
        )
        current_call = current_call.args[0]
        assert current_call.op.name == "nn.bias_add"
        assert (
            current_call.span is not None
            and current_call.span.source_name.name == f"image_block{block}.conv"
        )
        current_call = current_call.args[0]
        assert current_call.op.name == "nn.conv2d"
        assert (
            current_call.span is not None
            and current_call.span.source_name.name == f"image_block{block}.conv"
        )
        current_call = current_call.args[0]
```
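Beyond the explicit argument-by-argument walk in the test, a small helper along these lines (an illustration, not part of the PR) can dump every span source name the importer attached; it uses only attributes the test already relies on (`span.source_name.name`) plus `relay.analysis.post_order_visit`.

```python
import tvm
from tvm import relay


def collect_span_names(mod):
    """Return (op name, span source name) pairs for every Call that carries a span."""
    names = []

    def visit(expr):
        if isinstance(expr, relay.Call) and expr.span is not None:
            op_name = getattr(expr.op, "name", str(expr.op))
            names.append((op_name, expr.span.source_name.name))

    relay.analysis.post_order_visit(mod["main"].body, visit)
    return names


# With preserve_pytorch_scopes=True on the test model above, the result is expected
# to contain entries such as ("nn.conv2d", "image_block1.conv") and ("add", "final_block").
```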