python/tvm/relay/expr.py (1 addition, 1 deletion)

@@ -605,7 +605,7 @@ def astext(self):
     def __getitem__(self, index):
         if index >= len(self):
             raise IndexError("Tuple index out of range")
-        return TupleGetItem(self.tuple_value, index)
+        return TupleGetItem(self.tuple_value, index, span=self.tuple_value.span)

     def __len__(self):
         return self.size
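What the one-line change buys: indexing a TupleWrapper now stamps the producing tuple's span onto the TupleGetItem it creates, so source provenance survives the implicit unpacking the frontends do. A minimal sketch of the behavior, assuming spans are built with tvm.ir.Span; the names below are illustrative, not from the diff.

import tvm
from tvm import relay
from tvm.relay.expr import TupleWrapper

# Build a tuple whose span records where it came from.
span = tvm.ir.Span(tvm.ir.SourceName("toy_model"), 1, 1, 0, 0)
x = relay.var("x", shape=(2,))
y = relay.var("y", shape=(2,))
tup = relay.Tuple([x, y], span=span)

# Indexing the wrapper now yields TupleGetItem(tup, 0, span=tup.span),
# so the item points back at the same source location as its tuple.
item = TupleWrapper(tup, 2)[0]
assert item.span.source_name.name == "toy_model"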
python/tvm/relay/frontend/common.py (4 additions)

@@ -1169,6 +1169,10 @@ def fill(self, sym):
             return sym
         elif isinstance(sym, np.ndarray):
             return sym
+        elif not sym:
+            # some op conversions may return None
+            # e.g. prim::device in frontend/pytorch.py
+            return sym

         raise RuntimeError(f"unsupported type {type(sym)}")

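Why the new branch matters: a few converters legitimately produce no expression (the comment cites prim::device in the PyTorch frontend), and the span filler should hand None back rather than fall through to the RuntimeError. A hedged sketch, assuming the filler is reached through the set_span helper defined in this same file and that it accepts a plain string as the span source:

from tvm.relay.frontend.common import set_span

# prim::device has no Relay equivalent, so its converter returns None;
# the new `elif not sym` branch passes that through unchanged instead of
# raising RuntimeError("unsupported type <class 'NoneType'>").
converted = None
assert set_span(converted, "prim::device_0") is None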
python/tvm/relay/frontend/pytorch.py (190 additions, 31 deletions)

(Large diff not rendered by default.)
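Since the central file is not rendered, here is only a hedged usage sketch of what it enables: after conversion, Relay expressions carry spans naming the Torch graph node they came from, visible when the module is printed. The model, shapes, and exact span text are illustrative assumptions, not taken from the diff.

import torch
import tvm
from tvm import relay

class Net(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

script = torch.jit.trace(Net().eval(), torch.randn(1, 4))
mod, params = relay.frontend.from_pytorch(script, [("x", (1, 4))])

# Assuming span filling is on by default after this PR, the printed IR
# annotates each call with its origin, e.g. a span mentioning aten::relu.
print(mod["main"])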

python/tvm/relay/frontend/qnn_torch.py (2 additions, 2 deletions)

@@ -534,7 +534,7 @@ def add_quant_params(params, quant_params):
             params[qparam.bias_var.name_hint] = tvm.nd.array(qparam.bias)


-def inline_input_quant_params_for_fx(graph, params):
+def inline_input_quant_params_for_fx(graph, params, param_debug_name_map):
     """
     Canonicalize input scale and zero point access for FX-quantized graphs.
     We expect input qparams to aten::quantize_per_tensor to be prim::Constant, but that's

@@ -568,7 +568,7 @@ def get_full_attr_name(current):
         out_name = node.output().debugName()

         if "_scale" in out_name or "_zero_point" in out_name:
-            full_attr = get_full_attr_name(node)
+            full_attr = param_debug_name_map[get_full_attr_name(node)]
             assert full_attr in params, "%s not found in param dict." % full_attr
             param_np = params[full_attr].numpy()
             new_const_node = graph.create("prim::Constant")
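What the new argument buys: TorchScript may unique a parameter's debug name, so the fully qualified getattr path recovered by get_full_attr_name is not always the literal key in params. The caller now supplies a map from recovered path to the actual key. A toy illustration with made-up names, not taken from the diff:

# Hypothetical inputs: the recovered attribute path versus the uniqued
# debug name the frontend actually stored the value under.
params = {"conv._packed_params._scale.1": 0.05}
param_debug_name_map = {
    "conv._packed_params._scale": "conv._packed_params._scale.1",
}

full_attr = param_debug_name_map["conv._packed_params._scale"]
assert full_attr in params, "%s not found in param dict." % full_attr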
tests/python/frontend/pytorch/qnn_test.py (19 additions, 5 deletions)

@@ -45,9 +45,15 @@ def torch_version_check():

 def get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=False, target="llvm"):
     input_shapes = [(input_name, ishape)]
-    mod, params = relay.frontend.from_pytorch(
-        script_module, input_shapes, keep_quantized_weight=keep_quantized_weight
-    )
+    with tvm.testing.disable_span_filling():
+        mod, params = relay.frontend.from_pytorch(
+            script_module, input_shapes, keep_quantized_weight=keep_quantized_weight
+        )
+    with tvm.testing.enable_span_filling():
+        mod_with_span, _ = relay.frontend.from_pytorch(
+            script_module, input_shapes, keep_quantized_weight=keep_quantized_weight
+        )
+    assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)

     if keep_quantized_weight:
         for p in params.values():
@@ -629,7 +635,11 @@ def pattern_table():

 def run_qnn_mergecomposite(script_module, input_name, ishape):
     input_shapes = [(input_name, ishape)]
-    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+    with tvm.testing.disable_span_filling():
+        mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+    with tvm.testing.enable_span_filling():
+        mod_with_span, _ = relay.frontend.from_pytorch(script_module, input_shapes)
+    assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
     pattern_table = get_pattern_table("test_table")
     with tvm.transform.PassContext(opt_level=3):
         pass_list = [

@@ -778,7 +788,11 @@ def forward(self, input):
     script_module = torch.jit.trace(model_int8, fp32_input).eval()

     input_infos = [("input", (fp32_input.shape, "float32"))]
-    mod, _ = relay.frontend.from_pytorch(script_module, input_infos)
+    with tvm.testing.disable_span_filling():
+        mod, _ = relay.frontend.from_pytorch(script_module, input_infos)
+    with tvm.testing.enable_span_filling():
+        mod_with_span, _ = relay.frontend.from_pytorch(script_module, input_infos)
+    assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
     output = mod["main"].body

     assert isinstance(output, relay.Tuple) and len(output) == 2
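The same three lines recur in every touched test, so the pattern is worth reading once in isolation: convert twice, with span filling disabled and then enabled, and require structural equality so that spans provably never change graph structure. A condensed sketch using only APIs that appear in the diff; the helper name is ours.

import tvm
import tvm.testing
from tvm import relay

def assert_spans_do_not_change_structure(script_module, input_infos):
    # Convert once without spans and once with them...
    with tvm.testing.disable_span_filling():
        mod, _ = relay.frontend.from_pytorch(script_module, input_infos)
    with tvm.testing.enable_span_filling():
        mod_with_span, _ = relay.frontend.from_pytorch(script_module, input_infos)
    # ...then demand structural equality: spans must be metadata only.
    assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)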