From 443580ed7809a8836ab5a5765a7870b77a122635 Mon Sep 17 00:00:00 2001 From: Joey Tsai Date: Mon, 20 Feb 2023 12:37:36 +0800 Subject: [PATCH 1/3] [Relay][Frontend] Span Filling PyTorch - Construct debug name of C graph instruction as the source name of span for pytorch model. - To get the reference of renamed nodes. Add a function to export the converted C graph after conversion. - Add structural_equal comparisons with and without set_span to the existing test cases. - Add span test cases for frequent conversions. - Add span test case for exporting model parameter. --- python/tvm/relay/frontend/pytorch.py | 218 ++++++++++++-- tests/python/frontend/pytorch/qnn_test.py | 24 +- tests/python/frontend/pytorch/test_forward.py | 284 +++++++++++++++++- .../python/frontend/pytorch/test_fx_quant.py | 7 +- tests/python/frontend/pytorch/test_lstm.py | 6 +- .../frontend/pytorch/test_object_detection.py | 6 +- tests/python/frontend/pytorch/test_rnns.py | 16 +- 7 files changed, 515 insertions(+), 46 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 919ac65f504a..4cc4bd2a8b90 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -22,6 +22,7 @@ import functools import itertools import math +import re import sys import numpy as np @@ -44,6 +45,7 @@ from .common import infer_value as _infer_value from .common import infer_value_simulated as _infer_value_simulated from .common import lstm_cell, try_infer_value, unbind, fold_constant +from .common import set_span from .pytorch_utils import is_version_greater_than, getattr_attr_name __all__ = ["from_pytorch"] @@ -135,11 +137,14 @@ def _is_int_seq(seq): class PyTorchOpConverter: """A helper class for holding PyTorch op converters.""" - def __init__(self, prelude, default_dtype): + def __init__(self, prelude, default_dtype, use_parser_friendly_name=False): self.prelude = prelude self.default_dtype = default_dtype self.create_convert_map() self.types = {} # map from nodes to (Relay) type annotations + self.source_map = {} # map from graph node to its source name + self.op_type_dict = {} # map from op type to its presenting order + self.use_parser_friendly_name = use_parser_friendly_name # this incrementally infers the type, see the comments on the type visitor # above. 
@@ -2391,11 +2396,16 @@ def nms(self, inputs, input_types):
         iou_threshold = inputs[2]
         # TVM NMS assumes score > 0
-        scores = scores - _op.min(scores) + _op.const(1.0)
+        # - since there exist multiple consumers for "scores" and "num_boxes"
+        # - invoke set_span here to prevent the exprs from being rewritten in the span-filling stage
+        source_name = self._get_source_name_from_parameter(boxes)
+        scores = set_span(scores - _op.min(scores) + _op.const(1.0), source_name)
-        num_boxes = _op.shape_of(scores)
+        num_boxes = set_span(_op.shape_of(scores), source_name)
         # PyTorch NMS doesn't have score_threshold, so no need to run get_valid_count
-        indices = _op.transform.arange(_op.squeeze(num_boxes), dtype="int32")
+        # - since the "arange" op will fill the expr into its attribute
+        # - invoke set_span here to prevent the expr from being rewritten in the span-filling stage
+        indices = _op.transform.arange(set_span(_op.squeeze(num_boxes), source_name), dtype="int32")
         indices = _op.expand_dims(indices, 0, 1)
         # Generate data with shape (1, num_anchors, 5)
@@ -3869,7 +3879,12 @@ def report_missing_conversion(self, op_names):
     def convert_block(self, block, outputs):
         """Translate Torch "Block", used for prim::If and prim::Loop"""
-        ops = _get_operator_nodes(block.nodes())
+        ops = _get_operator_nodes(
+            block.nodes(),
+            self.source_map,
+            self.op_type_dict,
+            self.use_parser_friendly_name,
+        )
         ret_names = _get_input_names(block.returnNode())
         return self.convert_operators(ops, outputs, ret_names)
@@ -3940,13 +3955,19 @@ def get_var(name, val):
                         actual_shape.append(Any())
                     else:
                         actual_shape.append(dim)
-                return _expr.var(name, shape=actual_shape, dtype=checked_type.dtype)
+                expr = _expr.var(name, shape=actual_shape, dtype=checked_type.dtype)
             else:
-                return _expr.var(name, type_annotation=checked_type)
+                expr = _expr.var(name, type_annotation=checked_type)
+            return set_span(expr, val.span) if val.span else expr
             return _expr.var(name)
-        loop_iter_var = _expr.var(block_input_names[0], shape=(), dtype=loop_iter_dtype)
-        loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]]
+        source_name = self.source_map[loop_node]
+        loop_iter_var = set_span(
+            _expr.var(block_input_names[0], shape=(), dtype=loop_iter_dtype), span=source_name
+        )
+        loop_vars = set_span(
+            [get_var(name, val) for name, val in name_val_pairs[1:]], span=source_name
+        )
         # Add non constant free variables to loop variables to prevent code blow up
         # Without this, if there are two for loops in a row, which often happens
@@ -3969,7 +3990,7 @@ def get_var(name, val):
                 prev_output = outputs[name]
                 new_loop_var = get_var(name, prev_output)
                 prev_outputs[name] = prev_output
-                outputs[name] = new_loop_var
+                outputs[name] = set_span(new_loop_var, source_name)
                 loop_vars.append(new_loop_var)
                 init_vals.append(prev_output)
@@ -4021,7 +4042,10 @@ def convert_operators(self, operators, outputs, ret_names):
             if operator == "prim::Constant":
                 outputs[node_name] = _get_constant(op_node)
             elif operator == "prim::ListConstruct" and _should_construct_dynamic_list(op_node):
-                outputs[node_name] = self.convert_to_list_adt(inputs)
+                outputs[node_name] = set_span(
+                    self.convert_to_list_adt(inputs),
+                    self.source_map[op_node],
+                )
             elif operator == "prim::ListConstruct":
                 # This assumes that no more elements will be appended to this list
                 # In this case, we keep the Python list
@@ -4038,25 +4062,31 @@ def _handel_nested_input(inputs):
                         inputs_list.append(inputs[i])
                     return _expr.Tuple(inputs_list)
-                outputs[node_name] = _handel_nested_input(inputs)
+                outputs[node_name] = set_span(
+
_handel_nested_input(inputs), + self.source_map[op_node], + ) elif operator in ["prim::ListUnpack", "prim::TupleUnpack"]: assert len(inputs) == 1 if isinstance(inputs[0], (list, _expr.TupleWrapper)): unpacked = inputs[0] else: - unpacked = _unpack_tuple(inputs[0]) + unpacked = set_span( + _unpack_tuple(inputs[0]), + self.source_map[op_node], + ) outputs.update(zip(_get_output_names(op_node), unpacked)) elif operator == "prim::prim::RaiseException": logger.warning("raising exceptions is ignored") outputs[node_name] = None elif operator == "prim::If": if_out = self.convert_if(op_node, outputs) - outputs[node_name] = if_out + outputs[node_name] = set_span(if_out, self.source_map[op_node]) elif operator == "prim::Loop": loop_out = self.convert_loop(op_node, outputs) unpacked_names = _get_output_names(op_node) assert len(loop_out) == len(unpacked_names) - outputs.update(zip(unpacked_names, loop_out)) + outputs.update(zip(unpacked_names, set_span(loop_out, self.source_map[op_node]))) else: if operator not in self.convert_map: # At this point, the only possible ops that are not in convert_map are @@ -4071,9 +4101,14 @@ def _handel_nested_input(inputs): else: relay_op = self.convert_map[operator] + self._set_parameter_source_name(op_node, outputs) relay_out = relay_op( - inputs, _get_input_types(op_node, outputs, default_dtype=self.default_dtype) + # since the elements in "outputs" may change due to span-filling process + # we have to call "_get_op_inputs" again rather than use "inputs" directly + _get_op_inputs(op_node, outputs), + _get_input_types(op_node, outputs, default_dtype=self.default_dtype), ) + relay_out = set_span(relay_out, self.source_map[op_node]) self.record_output_type(relay_out) if isinstance(relay_out, tuple): @@ -4087,6 +4122,38 @@ def _handel_nested_input(inputs): return [_wrap_const(outputs[ret_name]) for ret_name in ret_names] + def _set_parameter_source_name(self, op_node, outputs): + """A helper function to rewrite source_name of parameter.""" + for name in _get_input_names(op_node): + expr = outputs[name] + if isinstance(expr, (_expr.Var, _expr.Constant)): + name_sep = "_" if self.use_parser_friendly_name else "." + source_name = [self.source_map[op_node]] + if isinstance(expr, _expr.Var): + # variable name should have contained node source name + # for op with attributes in convert_params stage + # e.g. "aten::batch_norm_5.running_mean" + if expr.name_hint.startswith(source_name[0]): + source_name[0] = expr.name_hint + else: + source_name.append(expr.name_hint) + new_expr = set_span(expr, name_sep.join(source_name)) + outputs[name] = new_expr + + def _get_source_name_from_parameter(self, expr): + """A helper function to get source information of graph node from parameter.""" + if expr.span: + name_sep = "_" if self.use_parser_friendly_name else "." + source_name = expr.span.source_name.name + # discard variable/parameter name to get source_name of op node + # e.g. 
conv2d.w / conv2d_w -> conv2d + if isinstance(expr, _expr.Var): + postfix = f"{name_sep}{expr.name_hint}" + assert postfix in source_name + source_name = source_name[: -len(postfix)] + return source_name + return None + def _pytorch_result_type(dtypes, non_tensor_inputs): """This promotes TVM dtypes like PyTorch would""" @@ -4354,13 +4421,67 @@ def _get_constant(node): return None -def _get_operator_nodes(nodes): +def _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name): + """Rewrite debug name of node outputs with its operator type""" + + def _get_source_name(op_type): + op_idx = 0 + if op_type in op_type_dict: + op_idx = op_type_dict[op_type] + 1 + op_type_dict[op_type] = op_idx + return "_".join([op_type, str(op_idx)]) + + # get source name of operator and rename all of its outputs + # e.g. node.kind(): aten::adaptive_max_pool2d + # node_src_name -> aten::adaptive_max_pool2d_x + # output_1 -> aten::adaptive_max_pool2d_x_0 + # output_2 -> aten::adaptive_max_pool2d_x_1 + if node.kind() != "prim::GetAttr": + node_src_name = _get_source_name(node.kind()) + for index, output in enumerate(node.outputs()): + output.setDebugName("_".join([node_src_name, str(index)])) + # update source map + # if use_parser_friendly_name is True: e.g. prim::Constant_0 -> prim__Constant_0 + if use_parser_friendly_name: + node_src_name = re.sub(r":|\.", "_", node_src_name) + source_map[node] = node_src_name + + +def _debug_rename(graph, use_parser_friendly_name): + """Returns map between node and source name""" + source_map, op_type_dict = {}, {} + prim_with_blocks = ["prim::If", "prim::Loop"] + + def _traverse_graph(nodes): + for node in nodes: + if node.outputsSize() == 0: + continue + if node.kind() in prim_with_blocks: + for block in node.blocks(): + _traverse_graph(block.nodes()) + _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name) + + _traverse_graph(graph.nodes()) + return source_map + + +def _get_operator_nodes( + nodes, + source_map=None, + op_type_dict=None, + use_parser_friendly_name=False, +): """Returns torch IR nodes that need conversion to Relay""" - ops = [] + ops, should_rename_graph = [], all([source_map, op_type_dict]) is not None + # Traverse nodes and add to graph for node in nodes: if node.outputsSize() == 0: continue + + if should_rename_graph: + _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name) + if node.outputsSize() > 1: node_name = "_".join(_get_output_names(node)) else: @@ -4531,7 +4652,7 @@ def terminate(users): return get_use_chains(root_getattr_node, terminate) -def convert_params(graph, state_dict, use_parser_friendly_name=False): +def convert_params(graph, state_dict, source_map, use_parser_friendly_name=False): """ Return Relay vars and TVM NDArrays for input parameters A chain of prim::GetAttr nodes is processed one at a time @@ -4553,17 +4674,25 @@ def convert_params(graph, state_dict, use_parser_friendly_name=False): full_attr = _getattr_full_name(getattrs, attr_name_sep) full_attr_node_name = _get_output_name(getattrs[-1]) + # set variable name by concatenating first consumer's name with full attribute + # e.g. 
"aten::batch_norm_5.running_mean" + var_name = attr_name_sep.join( + [ + source_map[_get_users(getattrs[-1])[0]], + full_attr.split(attr_name_sep)[-1], + ] + ) if full_attr.endswith("_packed_params"): # for quantized models packed_param_map[full_attr_node_name] = full_attr elif full_attr in state_dict: - if full_attr in vars_by_name: - var = vars_by_name[full_attr] + if var_name in vars_by_name: + var = vars_by_name[var_name] else: torch_tensor = state_dict[full_attr] - tensor, var = _get_tensor_and_var(torch_tensor, full_attr) - param_tensors[full_attr] = tensor - vars_by_name[full_attr] = var + tensor, var = _get_tensor_and_var(torch_tensor, var_name) + param_tensors[var_name] = tensor + vars_by_name[var_name] = var params[full_attr_node_name] = var return params, param_tensors, packed_param_map @@ -4581,6 +4710,19 @@ def get_all_op_names(graph): return set(node.kind() for node in nodes) +def export_c_graph(location, graph): + """Convert the graph to an onnx model and export it to the location.""" + import datetime + import os + + if not os.path.exists(location): + os.makedirs(location) + time_stamp = datetime.datetime.now().strftime("%m_%d_%Y_%H_%M_%S") + fname = os.path.join(location, "tvm_exported_c_graph_{}.txt".format(time_stamp)) + with open(f"{fname}", "w") as f: + f.write(str(graph)) + + def from_pytorch( script_module, input_infos, @@ -4588,6 +4730,7 @@ def from_pytorch( default_dtype="float32", use_parser_friendly_name=False, keep_quantized_weight=False, + export_renamed_c_graph_path=None, ): """Load PyTorch model in the form of a scripted PyTorch model and convert into relay. The companion parameters will be handled automatically. @@ -4630,6 +4773,11 @@ def from_pytorch( we quantize weights in the frontend using a function that is equivalent to qnn.op.quantize(...) operating on Numpy arrays. + export_renamed_c_graph_path : str, optional + Export the renamed torch._C.Graph to the path. + During the conversion, variable names in torch._C.Graph are assigned based on op types. + The exported text file can be the reference to spans. 
+ Returns ------- mod : tvm.IRModule @@ -4644,7 +4792,7 @@ def from_pytorch( prelude = Prelude(mod) enable_lower_all_tuples = True - converter = PyTorchOpConverter(prelude, default_dtype) + converter = PyTorchOpConverter(prelude, default_dtype, use_parser_friendly_name) graph = script_module.graph.copy() @@ -4673,12 +4821,16 @@ def from_pytorch( new_names = [key.replace(".", "_") for key in params.keys()] params = dict(zip(new_names, params.values())) - param_vars, tensors, packed_param_map = convert_params(graph, params, use_parser_friendly_name) + # rename _C.Graph here for constructing meaningful source name of graph nodes + # by doing so, we could Use source_map as the reference to rename model parameters + source_map = _debug_rename(graph, use_parser_friendly_name) + param_vars, tensors, packed_param_map = convert_params( + graph, params, source_map, use_parser_friendly_name + ) tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()} outputs.update(param_vars) - ret_name = _get_input_names(graph.return_node()) # For quantized models quantized_ops = set(["aten::quantize_per_tensor", "quantized::linear_dynamic"]) @@ -4698,7 +4850,14 @@ def from_pytorch( qnn_torch.add_quant_params(tvm_params, weight_quant_params) converter.update_convert_map(qnn_torch.convert_map) - outputs = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name) + operator_nodes = _get_operator_nodes( + graph.nodes(), + converter.source_map, + converter.op_type_dict, + use_parser_friendly_name, + ) + ret_name = _get_input_names(graph.return_node()) + outputs = converter.convert_operators(operator_nodes, outputs, ret_name) # ListConstruct kept original python list. Convert to tuple. outputs = [_expr.Tuple(output) if isinstance(output, list) else output for output in outputs] @@ -4720,4 +4879,7 @@ def from_pytorch( mod["main"] = tvm.relay.Function(func_args, ret) + if export_renamed_c_graph_path: + export_c_graph(export_renamed_c_graph_path, graph) + return transform.RemoveUnusedFunctions()(mod), tvm_params diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index e9fbe12e9754..beaeeb999923 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -45,9 +45,15 @@ def torch_version_check(): def get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=False, target="llvm"): input_shapes = [(input_name, ishape)] - mod, params = relay.frontend.from_pytorch( - script_module, input_shapes, keep_quantized_weight=keep_quantized_weight - ) + with tvm.testing.disable_span_filling(): + mod, params = relay.frontend.from_pytorch( + script_module, input_shapes, keep_quantized_weight=keep_quantized_weight + ) + with tvm.testing.enable_span_filling(): + mod_with_span, _ = relay.frontend.from_pytorch( + script_module, input_shapes, keep_quantized_weight=keep_quantized_weight + ) + assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) if keep_quantized_weight: for p in params.values(): @@ -629,7 +635,11 @@ def pattern_table(): def run_qnn_mergecomposite(script_module, input_name, ishape): input_shapes = [(input_name, ishape)] - mod, params = relay.frontend.from_pytorch(script_module, input_shapes) + with tvm.testing.disable_span_filling(): + mod, params = relay.frontend.from_pytorch(script_module, input_shapes) + with tvm.testing.enable_span_filling(): + mod_with_span, _ = relay.frontend.from_pytorch(script_module, input_shapes) + assert tvm.ir.structural_equal(mod, mod_with_span, 
map_free_vars=True) pattern_table = get_pattern_table("test_table") with tvm.transform.PassContext(opt_level=3): pass_list = [ @@ -778,7 +788,11 @@ def forward(self, input): script_module = torch.jit.trace(model_int8, fp32_input).eval() input_infos = [("input", (fp32_input.shape, "float32"))] - mod, _ = relay.frontend.from_pytorch(script_module, input_infos) + with tvm.testing.disable_span_filling(): + mod, _ = relay.frontend.from_pytorch(script_module, input_infos) + with tvm.testing.enable_span_filling(): + mod_with_span, _ = relay.frontend.from_pytorch(script_module, input_infos) + assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) output = mod["main"].body assert isinstance(output, relay.Tuple) and len(output) == 2 diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 033ce64b3ac4..a60b64fad5b5 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -29,7 +29,8 @@ from tvm import relay from tvm.contrib import graph_executor from tvm.contrib.nvcc import have_fp16 -from tvm.contrib import cudnn +from tvm.contrib import cudnn, utils +from relay.utils.tag_span import _create_span, _set_span, _verify_structural_equal_with_span import torch from torch.nn import Module @@ -135,6 +136,7 @@ def verify_model( kind="graph", check_correctness=True, cpu_only=False, + validate_structural_equal=True, ): """Assert that the output of a compiled model matches with that of its baseline.""" @@ -175,7 +177,13 @@ def verify_model( input_names = [f"input{idx}" for idx, _ in enumerate(baseline_input)] input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input])) - mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) + with tvm.testing.disable_span_filling(): + mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) + if validate_structural_equal: + with tvm.testing.enable_span_filling(): + mod_with_span, _ = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) + assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) + for arg in mod["main"].params[: len(input_names)]: assert arg.name_hint in input_names compiled_input = dict(zip(input_names, [inp.clone().cpu().numpy() for inp in baseline_input])) @@ -231,6 +239,7 @@ def verify_model_with_input( rtol=1e-5, atol=1e-5, assert_shape_only=False, + validate_structural_equal=True, ): """Generic function to generate and compare Pytorch and TVM output""" input_dict = input_dict or {} @@ -239,7 +248,13 @@ def verify_model_with_input( trace = torch.jit.trace(test_func, [input.clone() for input in input_data]) input_names = [f"input{idx}" for idx, _ in enumerate(input_data)] input_shapes = list(zip(input_names, [inp.shape for inp in input_data])) - mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) + with tvm.testing.disable_span_filling(): + mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) + if validate_structural_equal: + with tvm.testing.enable_span_filling(): + mod_with_span, _ = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) + assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) + with tvm.transform.PassContext(opt_level=3): for target in ["llvm", "cuda"]: if not tvm.runtime.enabled(target): @@ -257,6 +272,20 @@ def verify_model_with_input( tvm.testing.assert_allclose(baseline_outputs, compiled_output, rtol=rtol, atol=atol) 
+def gen_ir_module(model, inputs, use_parser_friendly_name=False): + """Helper function to generate IRModule with meaningful source information""" + + trace = torch.jit.trace(model, inputs) + input_names = ["input{}".format(idx) for idx, _ in enumerate(inputs)] + input_shapes = list(zip(input_names, [inp.shape for inp in inputs])) + mod, _ = relay.frontend.from_pytorch( + trace, + input_shapes, + use_parser_friendly_name=use_parser_friendly_name, + ) + return mod + + # Single operator tests @tvm.testing.uses_gpu def test_forward_pixel_shuffle(): @@ -2596,7 +2625,11 @@ def verify_model_vm(input_model, ishapes, idtype=None, idata=None, targets=None) input_data = [torch.randn(shape, dtype=idtype) for shape in ishapes] # Compile via VM - mod, params = relay.frontend.from_pytorch(input_model, input_shapes) + with tvm.testing.disable_span_filling(): + mod, params = relay.frontend.from_pytorch(input_model, input_shapes) + with tvm.testing.enable_span_filling(): + mod_with_span, _ = relay.frontend.from_pytorch(input_model, input_shapes) + assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) for tgt in targets: if not tvm.testing.device_enabled(tgt): @@ -3951,7 +3984,8 @@ def forward(self, x): def test_weight_names(): tm = torch.jit.trace(torch.nn.Linear(3, 4), [torch.randn(2, 3)]) _, params = relay.frontend.from_pytorch(tm, [("input", (2, 3))]) - assert set(params.keys()) == set(n for n, _ in tm.named_parameters()) + keys = [key.split(".")[-1] for key in params.keys()] + assert set(keys) == set(n for n, p in tm.named_parameters()) @tvm.testing.uses_gpu @@ -4320,12 +4354,12 @@ def test_randn(): def test_func(): return torch.randn([1, 3, 10, 10]) - verify_model_with_input(test_func, [], assert_shape_only=True) + verify_model_with_input(test_func, [], assert_shape_only=True, validate_structural_equal=False) def test_func1(): return torch.randn(1, 3, 10, 10) - verify_model_with_input(test_func1, [], assert_shape_only=True) + verify_model_with_input(test_func1, [], assert_shape_only=True, validate_structural_equal=False) def test_forward_pretrained_bert_base_uncased(): @@ -5102,18 +5136,25 @@ def _test_trilu(op, diagonal): def test_multinomial(): + """test_multinomial""" + def _test_multinomial(num_samples): return lambda inp: torch.multinomial(inp, num_samples=num_samples, replacement=True) # Dont check output since it's random. Instead we'll just make sure shapes are right. 
verify_model( - _test_multinomial(2), [torch.rand(size=[3]).float()], cpu_only=True, check_correctness=False + _test_multinomial(2), + [torch.rand(size=[3]).float()], + cpu_only=True, + check_correctness=False, + validate_structural_equal=False, ) verify_model( _test_multinomial(1), [torch.rand(size=[4, 5]).float()], cpu_only=True, check_correctness=False, + validate_structural_equal=False, ) @@ -5155,5 +5196,232 @@ def test_fn(alpha, beta): verify_model(test_fn(0.5, 1.0), [M, batch1, batch2]) +def test_exporting_renamed_c_graph(): + """test exproting model when export_renamed_model is set""" + + # model definition + class Conv2D(Module): + def __init__(self): + super(Conv2D, self).__init__() + self.conv = torch.nn.Conv2d(3, 6, 3, bias=True) + + def forward(self, *args): + return self.conv(args[0]) + + input_name, input_shape = "input", [1, 3, 10, 10] + shape_list = [(input_name, input_shape)] + temp_dir = utils.tempdir().path + script_module = torch.jit.trace(Conv2D(), [torch.rand(input_shape)]) + _, _ = relay.frontend.from_pytorch( + script_module, shape_list, export_renamed_c_graph_path=temp_dir + ) + + exported_c_graph_name = os.listdir(temp_dir)[0] + assert "tvm_exported_c_graph_" in exported_c_graph_name + + # make sure the renamed output variable presents in the restored _C.Graph + with open(f"{temp_dir}/{exported_c_graph_name}", "r") as f: + graph = f.read() + assert "%aten::_convolution_0" in graph + + +class TestSetSpan: + """test structural equal between translated / hand-crafted relay IR with span tagged.""" + + def _verify(self, res_fptr, golden_fptr): + with tvm.testing.enable_span_filling(): + with_span = res_fptr() + with tvm.testing.disable_span_filling(): + without_span = res_fptr() + assert tvm.ir.structural_equal(with_span, without_span) + _verify_structural_equal_with_span(with_span, golden_fptr()) + + def test_conv2d_bias_add(self): + ker_sz, in_chs, out_chs = 7, 3, 6 + input_shape = [1, 3, 10, 10] + + def _res(): + # model definition + class Conv2D(Module): + def __init__(self): + super(Conv2D, self).__init__() + self.conv = torch.nn.Conv2d(in_chs, out_chs, ker_sz, bias=True) + + def forward(self, *args): + return self.conv(args[0]) + + # get frontend model + mod = gen_ir_module(Conv2D(), [torch.rand(input_shape)]) + return mod["main"] + + def _golden(): + conv_si = "aten::_convolution_0" + input_name = "input0" + input_0 = relay.var( + input_name, + shape=tuple(input_shape), + span=_create_span(f"{conv_si}.{input_name}"), + ) + weight_name = f"{conv_si}.weight" + conv_weight = relay.var( + weight_name, + shape=(out_chs, in_chs, ker_sz, ker_sz), + span=_create_span(weight_name), + ) + bias_name = f"{conv_si}.bias" + conv_bias = relay.var( + bias_name, + shape=(out_chs,), + span=_create_span(bias_name), + ) + conv_out = _set_span( + relay.nn.conv2d( + input_0, + conv_weight, + padding=[0] * 4, + channels=out_chs, + kernel_size=[ker_sz] * 2, + ), + conv_si, + ) + bias_out = _set_span(relay.nn.bias_add(conv_out, conv_bias), conv_si) + return relay.Function([input_0, conv_weight, conv_bias], bias_out) + + self._verify(_res, _golden) + + def test_batchnorm_span(self): + features = 16 + input_shape = [1, 16, 10, 10] + + def _res(): + # model definition + bn_2d = torch.nn.BatchNorm2d(features) + + # get frontend model + mod = gen_ir_module(bn_2d, [torch.rand(input_shape)]) + return mod["main"] + + def _golden(): + bn_si = "aten::batch_norm_0" + input_name = "input0" + input_0 = relay.var( + input_name, + shape=tuple(input_shape), + 
span=_create_span(f"{bn_si}.{input_name}"), + ) + weight_name = f"{bn_si}.weight" + bn_weight = relay.var( + weight_name, + shape=(features,), + span=_create_span(weight_name), + ) + bias_name = f"{bn_si}.bias" + bn_bias = relay.var( + bias_name, + shape=(features,), + span=_create_span(bias_name), + ) + rm_name = f"{bn_si}.running_mean" + bn_rm = relay.var( + rm_name, + shape=(features,), + span=_create_span(rm_name), + ) + rv_name = f"{bn_si}.running_var" + bn_rv = relay.var( + rv_name, + shape=(features,), + span=_create_span(rv_name), + ) + bn_out = _set_span( + relay.nn.batch_norm(input_0, bn_weight, bn_bias, bn_rm, bn_rv), + bn_si, + ) + bn_tuple_get_item = _set_span(relay.TupleGetItem(bn_out.tuple_value, 0), bn_si) + return relay.Function([input_0, bn_weight, bn_bias, bn_rm, bn_rv], bn_tuple_get_item) + + self._verify(_res, _golden) + + def test_reshape_span(self): + input_shape = [2, 1, 10, 1, 10] + new_shape = [2, 1, 10, 10] + + def _res(): + # model definition + class Reshape(Module): + def forward(self, *args): + return args[0].reshape(new_shape) + + # get frontend model + mod = gen_ir_module(Reshape(), [torch.rand(input_shape)]) + return mod["main"] + + def _golden(): + reshape_si = "aten::reshape_0" + input_name = "input0" + input_0 = relay.var( + input_name, + shape=tuple(input_shape), + span=_create_span(f"{reshape_si}.{input_name}"), + ) + reshape_out = _set_span( + relay.reshape(input_0, newshape=new_shape), + reshape_si, + ) + return relay.Function([input_0], reshape_out) + + self._verify(_res, _golden) + + def test_dense_bias_add(self): + in_f, out_f = 10, 7 + input_shape = [in_f, in_f] + + def _res(): + # model definition + class Dense(Module): + def __init__(self): + super(Dense, self).__init__() + self.linear = torch.nn.Linear(in_f, out_f, bias=True) + + def forward(self, *args): + return self.linear(args[0]) + + # get frontend model + mod = gen_ir_module(Dense(), [torch.rand(input_shape)]) + return mod["main"] + + def _golden(): + dense_si = "aten::linear_0" + input_name = "input0" + input_0 = relay.var( + input_name, + shape=tuple(input_shape), + span=_create_span(f"{dense_si}.{input_name}"), + ) + weight_name = f"{dense_si}.weight" + dense_weight = relay.var( + weight_name, + shape=(out_f, in_f), + span=_create_span(weight_name), + ) + bias_name = f"{dense_si}.bias" + dense_bias = relay.var( + bias_name, + shape=(out_f,), + span=_create_span(bias_name), + ) + dense_out = _set_span( + relay.nn.dense(input_0, dense_weight), + dense_si, + ) + bias_out = _set_span( + relay.nn.bias_add(dense_out, dense_bias, axis=-1), + dense_si, + ) + return relay.Function([input_0, dense_weight, dense_bias], bias_out) + + self._verify(_res, _golden) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/frontend/pytorch/test_fx_quant.py b/tests/python/frontend/pytorch/test_fx_quant.py index f35094a83137..564900cbf209 100644 --- a/tests/python/frontend/pytorch/test_fx_quant.py +++ b/tests/python/frontend/pytorch/test_fx_quant.py @@ -23,6 +23,7 @@ from torchvision.models.efficientnet import efficientnet_b4 from torchvision.models.resnet import resnet50 from tvm import relay +import tvm.testing def quantize(model): @@ -38,7 +39,11 @@ def quantize_and_build(model, in_size): with torch.no_grad(): script_module = torch.jit.trace(qmodel, inp) - mod, _ = relay.frontend.from_pytorch(script_module, [(input_name, inp.shape)]) + with tvm.testing.disable_span_filling(): + mod, _ = relay.frontend.from_pytorch(script_module, [(input_name, inp.shape)]) + with 
tvm.testing.enable_span_filling(): + mod_with_span, _ = relay.frontend.from_pytorch(script_module, [(input_name, inp.shape)]) + assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) mod = relay.transform.InferType()(mod) # Make sure that the model is quantized diff --git a/tests/python/frontend/pytorch/test_lstm.py b/tests/python/frontend/pytorch/test_lstm.py index 25d4563ee64e..e9dd2b380c1e 100644 --- a/tests/python/frontend/pytorch/test_lstm.py +++ b/tests/python/frontend/pytorch/test_lstm.py @@ -337,7 +337,11 @@ def test_custom_lstm(): for (name, raw_model, states, input_shapes) in models: script_module = torch.jit.script(raw_model) - mod, params = from_pytorch(script_module, input_shapes) + with tvm.testing.disable_span_filling(): + mod, params = from_pytorch(script_module, input_shapes) + with tvm.testing.enable_span_filling(): + mod_with_span, _ = from_pytorch(script_module, input_shapes) + assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) with torch.no_grad(): pt_result = raw_model(inp.clone(), states) diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index 83b13f686be2..25e784b00a1b 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -104,7 +104,11 @@ def test_detection_models(): shape_list = [(input_name, input_shape)] scripted_model = generate_jit_model(1) - mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) + with tvm.testing.disable_span_filling(): + mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) + with tvm.testing.enable_span_filling(): + mod_with_span, _ = relay.frontend.from_pytorch(scripted_model, shape_list) + assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) data = process_image(img) data_np = data.detach().numpy() diff --git a/tests/python/frontend/pytorch/test_rnns.py b/tests/python/frontend/pytorch/test_rnns.py index fba55b9c4c8f..3ea423250010 100644 --- a/tests/python/frontend/pytorch/test_rnns.py +++ b/tests/python/frontend/pytorch/test_rnns.py @@ -456,7 +456,15 @@ def get_onnx_model(model): traced_script_module = torch.jit.trace(model, dummy_inputs[0]).eval() # Import model to Relay - mod, params = relay.frontend.from_pytorch(traced_script_module, shape_desc) + with tvm.testing.disable_span_filling(): + mod, params = relay.frontend.from_pytorch( + traced_script_module, shape_desc + ) + with tvm.testing.enable_span_filling(): + mod_with_span, _ = relay.frontend.from_pytorch( + traced_script_module, shape_desc + ) + assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) elif format == "onnx": try: onnx_model = get_onnx_model(model) @@ -468,7 +476,11 @@ def get_onnx_model(model): continue # Import model to Relay - mod, params = relay.frontend.from_onnx(onnx_model, shape_desc) + with tvm.testing.disable_span_filling(): + mod, params = relay.frontend.from_onnx(onnx_model, shape_desc) + with tvm.testing.enable_span_filling(): + mod_with_span, _ = relay.frontend.from_onnx(onnx_model, shape_desc) + assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) # Model compilation by tvm with tvm.transform.PassContext(opt_level=3): From c7cb7e21d98d9669d9987c5c7a5fd94a9b3ca881 Mon Sep 17 00:00:00 2001 From: Joey Tsai Date: Fri, 24 Feb 2023 09:43:05 +0800 Subject: [PATCH 2/3] [SpanFillingPyTorch] - Return TupleGetItem expr from TupleWrapper with the span of its Tuple. 
- Handle None results in set_span for certain conversions.
- Add a current_op member variable to PyTorchOpConverter to track which op is currently being converted in the PyTorch frontend.
---
 python/tvm/relay/expr.py             |  2 +-
 python/tvm/relay/frontend/common.py  |  4 +++
 python/tvm/relay/frontend/pytorch.py | 40 ++++++++++++----------------
 3 files changed, 22 insertions(+), 24 deletions(-)

diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index cb14552ac16e..d8bca5c4a431 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -605,7 +605,7 @@ def astext(self):
     def __getitem__(self, index):
         if index >= len(self):
             raise IndexError("Tuple index out of range")
-        return TupleGetItem(self.tuple_value, index)
+        return TupleGetItem(self.tuple_value, index, span=self.tuple_value.span)
     def __len__(self):
         return self.size
diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 5d3b0a334590..39e17b27da2a 100644
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -1169,6 +1169,10 @@ def fill(self, sym):
             return sym
         elif isinstance(sym, np.ndarray):
             return sym
+        elif not sym:
+            # some op conversions may return None
+            # e.g. op in frontend/pytorch.py: prim::device
+            return sym
         raise RuntimeError(f"unsupported type {type(sym)}")
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 4cc4bd2a8b90..0855a6ff462d 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -144,6 +144,7 @@ def __init__(self, prelude, default_dtype, use_parser_friendly_name=False):
         self.types = {}  # map from nodes to (Relay) type annotations
         self.source_map = {}  # map from graph node to its source name
         self.op_type_dict = {}  # map from op type to its presenting order
+        self.current_op = []  # stack for recording the op currently being converted
         self.use_parser_friendly_name = use_parser_friendly_name
     # this incrementally infers the type, see the comments on the type visitor
     # above.
@@ -337,7 +338,10 @@ def arange(self, inputs, input_types):
         def _get_value(val, dtype):
             # dtype is a tvm dtype
             if isinstance(val, _expr.Expr):
-                inp = _op.cast(val, dtype)
+                # since the "arange" op will fill the expr into its attribute
+                # invoke set_span here to prevent the expr from being rewritten in the span-filling stage
+                source_name = self.source_map[self.current_op[-1]]
+                inp = set_span(_op.cast(val, dtype), source_name)
                 ret, _ = try_infer_value(inp, lambda ret: _expr.const(ret, dtype))
             else:
                 ret = _create_typed_const(val, dtype)
@@ -2398,7 +2402,7 @@ def nms(self, inputs, input_types):
         # TVM NMS assumes score > 0
         # - since there exist multiple consumers for "scores" and "num_boxes"
         # - invoke set_span here to prevent the exprs from being rewritten in the span-filling stage
-        source_name = self._get_source_name_from_parameter(boxes)
+        source_name = self.source_map[self.current_op[-1]]
         scores = set_span(scores - _op.min(scores) + _op.const(1.0), source_name)
         num_boxes = set_span(_op.shape_of(scores), source_name)
@@ -4038,6 +4042,9 @@ def convert_operators(self, operators, outputs, ret_names):
         for node_name, op_node in operators:
             operator = op_node.kind()
             inputs = _get_op_inputs(op_node, outputs)
+            # record the current operator to provide the correct source name
+            # for operators that need special handling (e.g. nms / arange ...)
+ self.current_op.append(op_node) if operator == "prim::Constant": outputs[node_name] = _get_constant(op_node) @@ -4071,11 +4078,10 @@ def _handel_nested_input(inputs): if isinstance(inputs[0], (list, _expr.TupleWrapper)): unpacked = inputs[0] else: - unpacked = set_span( - _unpack_tuple(inputs[0]), - self.source_map[op_node], - ) - outputs.update(zip(_get_output_names(op_node), unpacked)) + unpacked = _unpack_tuple(inputs[0]) + outputs.update( + zip(_get_output_names(op_node), set_span(unpacked, self.source_map[op_node])) + ) elif operator == "prim::prim::RaiseException": logger.warning("raising exceptions is ignored") outputs[node_name] = None @@ -4120,6 +4126,8 @@ def _handel_nested_input(inputs): assert op_node.outputsSize() == 1 outputs[node_name] = relay_out + self.current_op.pop() + return [_wrap_const(outputs[ret_name]) for ret_name in ret_names] def _set_parameter_source_name(self, op_node, outputs): @@ -4140,20 +4148,6 @@ def _set_parameter_source_name(self, op_node, outputs): new_expr = set_span(expr, name_sep.join(source_name)) outputs[name] = new_expr - def _get_source_name_from_parameter(self, expr): - """A helper function to get source information of graph node from parameter.""" - if expr.span: - name_sep = "_" if self.use_parser_friendly_name else "." - source_name = expr.span.source_name.name - # discard variable/parameter name to get source_name of op node - # e.g. conv2d.w / conv2d_w -> conv2d - if isinstance(expr, _expr.Var): - postfix = f"{name_sep}{expr.name_hint}" - assert postfix in source_name - source_name = source_name[: -len(postfix)] - return source_name - return None - def _pytorch_result_type(dtypes, non_tensor_inputs): """This promotes TVM dtypes like PyTorch would""" @@ -4775,8 +4769,8 @@ def from_pytorch( export_renamed_c_graph_path : str, optional Export the renamed torch._C.Graph to the path. - During the conversion, variable names in torch._C.Graph are assigned based on op types. - The exported text file can be the reference to spans. + During the conversion, variable names in torch._C.Graph will be assigned based on their op + types. The exported text file can be the reference to spans. Returns ------- From 16c582f4f58ef2af506203becb45825172b701b6 Mon Sep 17 00:00:00 2001 From: Joey Tsai Date: Wed, 1 Mar 2023 09:49:30 +0800 Subject: [PATCH 3/3] [SpanFillingPyTorch] - Fix the error caused by the quantized params not found after renaming the debug name of C graph. --- python/tvm/relay/frontend/pytorch.py | 9 ++++++--- python/tvm/relay/frontend/qnn_torch.py | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 0855a6ff462d..0bf61373e3ef 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -4655,6 +4655,7 @@ def convert_params(graph, state_dict, source_map, use_parser_friendly_name=False params = {} param_tensors = {} packed_param_map = {} + param_debug_name_map = {} vars_by_name = {} seen = set() attr_name_sep = "_" if use_parser_friendly_name else "." 
@@ -4686,10 +4687,12 @@ def convert_params(graph, state_dict, source_map, use_parser_friendly_name=False torch_tensor = state_dict[full_attr] tensor, var = _get_tensor_and_var(torch_tensor, var_name) param_tensors[var_name] = tensor + # for quantized parameters to be correctly located + param_debug_name_map[full_attr_node_name] = var_name vars_by_name[var_name] = var params[full_attr_node_name] = var - return params, param_tensors, packed_param_map + return params, param_tensors, packed_param_map, param_debug_name_map def get_all_op_names(graph): @@ -4818,7 +4821,7 @@ def from_pytorch( # rename _C.Graph here for constructing meaningful source name of graph nodes # by doing so, we could Use source_map as the reference to rename model parameters source_map = _debug_rename(graph, use_parser_friendly_name) - param_vars, tensors, packed_param_map = convert_params( + param_vars, tensors, packed_param_map, param_debug_name_map = convert_params( graph, params, source_map, use_parser_friendly_name ) @@ -4832,7 +4835,7 @@ def from_pytorch( weight_quant_params = qnn_torch.get_weight_quant_params( script_module, packed_param_map.values() ) - qnn_torch.inline_input_quant_params_for_fx(graph, tensors) + qnn_torch.inline_input_quant_params_for_fx(graph, tensors, param_debug_name_map) input_scales_for_bias = qnn_torch.add_input_quant_params_to_op_inputs(graph) qnn_torch.add_quant_params_to_outputs( outputs, diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index a4eb56c1048a..131a471fd5c3 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -534,7 +534,7 @@ def add_quant_params(params, quant_params): params[qparam.bias_var.name_hint] = tvm.nd.array(qparam.bias) -def inline_input_quant_params_for_fx(graph, params): +def inline_input_quant_params_for_fx(graph, params, param_debug_name_map): """ Canonicalize input scale and zero point access for FX-quantized graphs. We expect input qparams to aten::quantize_per_tensor to be prim::Constant, but that's @@ -568,7 +568,7 @@ def get_full_attr_name(current): out_name = node.output().debugName() if "_scale" in out_name or "_zero_point" in out_name: - full_attr = get_full_attr_name(node) + full_attr = param_debug_name_map[get_full_attr_name(node)] assert full_attr in params, "%s not found in param dict." % full_attr param_np = params[full_attr].numpy() new_const_node = graph.create("prim::Constant")
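
Editor's usage note (not part of the patch): the sketch below illustrates how the span-filling and `export_renamed_c_graph_path` features introduced by this series could be exercised from user code. The toy model, the output directory, and the `print_spans` helper are illustrative assumptions; only the `from_pytorch` argument and the `aten::<op>_<index>` source-name convention come from the patch itself.

```python
# Minimal, hypothetical usage sketch for the span-filling feature in this patch set.
# Assumes a TVM build that already includes these changes; the model, path, and the
# span-printing helper below are illustrative only.
import torch
import tvm
from tvm import relay

model = torch.nn.Sequential(torch.nn.Conv2d(3, 6, 3), torch.nn.ReLU()).eval()
inp = torch.rand(1, 3, 10, 10)
trace = torch.jit.trace(model, inp)

# The renamed torch._C.Graph is dumped as tvm_exported_c_graph_<timestamp>.txt under
# this directory, so it can be cross-referenced against the spans in the Relay module.
mod, params = relay.frontend.from_pytorch(
    trace,
    [("input0", inp.shape)],
    export_renamed_c_graph_path="./exported_c_graph",
)

def print_spans(expr):
    # Each converted Relay expr carries the renamed graph-node name in its span,
    # e.g. "aten::_convolution_0" or "aten::relu_0".
    span = getattr(expr, "span", None)
    if span:
        print(type(expr).__name__, span.source_name.name)

relay.analysis.post_order_visit(mod["main"], print_spans)
```

The printed span names can then be matched against the `%aten::_convolution_0`-style debug names in the exported graph text, which is the cross-reference checked by `test_exporting_renamed_c_graph` above.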